Add support for claiming multiple OTKs at once. (#15468)

MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
This commit is contained in:
Patrick Cloke 2023-04-27 12:57:46 -04:00 committed by GitHub
parent 6efa674004
commit 57aeeb308b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 271 additions and 98 deletions

1
changelog.d/15468.misc Normal file
View File

@ -0,0 +1 @@
Support claiming more than one OTK at a time.

View File

@ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient):
return False return False
async def claim_client_keys( async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]] self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: ) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from an application service. """Claim one time keys from an application service.
Note that any error (including a timeout) is treated as the application Note that any error (including a timeout) is treated as the application
@ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient):
# Create the expected payload shape. # Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {} body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query: for user_id, device, algorithm, count in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) body.setdefault(user_id, {}).setdefault(device, []).extend(
[algorithm] * count
)
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try: try:
@ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient):
# or if some are still missing. # or if some are still missing.
# #
# TODO This places a lot of faith in the response shape being correct. # TODO This places a lot of faith in the response shape being correct.
missing = [ missing = []
(user_id, device, algorithm) for user_id, device, algorithm, count in query:
for user_id, device, algorithm in query # Count the number of keys in the response for this algorithm by
if algorithm not in response.get(user_id, {}).get(device, []) # checking which key IDs start with the algorithm. This uses that
] # True == 1 in Python to generate a count.
response_count = sum(
key_id.startswith(f"{algorithm}:")
for key_id in response.get(user_id, {}).get(device, {})
)
count -= response_count
# If the appservice responds with fewer keys than requested, then
# consider the request unfulfilled.
if count > 0:
missing.append((user_id, device, algorithm, count))
return response, missing return response, missing

View File

@ -235,7 +235,10 @@ class FederationClient(FederationBase):
) )
async def claim_client_keys( async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: Optional[int] self,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict: ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
@ -247,6 +250,50 @@ class FederationClient(FederationBase):
The JSON object from the response The JSON object from the response
""" """
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
use_unstable = False
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
if any(count > 1 for count in algorithms.values()):
use_unstable = True
if algorithms:
# For the stable query, choose only the first algorithm.
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
# For the unstable query, repeat each algorithm by count, then
# splat those into chain to get a flattened list of all algorithms.
#
# Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
unstable_content.setdefault(user_id, {})[device_id] = list(
itertools.chain(
*(
itertools.repeat(algorithm, count)
for algorithm, count in algorithms.items()
)
)
)
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
)
else:
logger.debug("Skipping unstable claim client keys API")
return await self.transport_layer.claim_client_keys( return await self.transport_layer.claim_client_keys(
destination, content, timeout destination, content, timeout
) )

View File

@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
@trace @trace
async def on_claim_client_keys( async def on_claim_client_keys(
self, origin: str, content: JsonDict, always_include_fallback_keys: bool self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]: ) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys( results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys query, always_include_fallback_keys=always_include_fallback_keys

View File

@ -650,10 +650,10 @@ class TransportLayerClient:
Response: Response:
{ {
"device_keys": { "one_time_keys": {
"<user_id>": { "<user_id>": {
"<device_id>": { "<device_id>": {
"<algorithm>:<key_id>": "<key_base64>" "<algorithm>:<key_id>": <OTK JSON>
} }
} }
} }
@ -669,7 +669,50 @@ class TransportLayerClient:
path = _create_v1_path("/user/keys/claim") path = _create_v1_path("/user/keys/claim")
return await self.client.post_json( return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)
async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {"<algorithm>": <count>}
}
}
}
Response:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": <OTK JSON>
}
}
}
}
Args:
destination: The server to query.
query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
return await self.client.post_json(
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
) )
async def get_missing_events( async def get_missing_events(

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from collections import Counter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Dict, Dict,
@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST( async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
key_query.append((user_id, device_id, algorithm, 1))
response = await self.handler.on_claim_client_keys( response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False key_query, always_include_fallback_keys=False
) )
return 200, response return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
""" """
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it Identical to the stable endpoint (FederationClientKeysClaimServlet) except
always includes fallback keys in the response. it allows for querying for multiple OTKs at once and always includes fallback
keys in the response.
""" """
PREFIX = FEDERATION_UNSTABLE_PREFIX PREFIX = FEDERATION_UNSTABLE_PREFIX
@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST( async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
counts = Counter(algorithms)
for algorithm, count in counts.items():
key_query.append((user_id, device_id, algorithm, count))
response = await self.handler.on_claim_client_keys( response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True key_query, always_include_fallback_keys=True
) )
return 200, response return 200, response
@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet, FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
FederationVersionServlet, FederationVersionServlet,

View File

@ -841,8 +841,10 @@ class ApplicationServicesHandler:
return True return True
async def claim_e2e_one_time_keys( async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]] self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: ) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from application services. """Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a Users which are exclusively owned by an application service are sent a
@ -863,18 +865,18 @@ class ApplicationServicesHandler:
services = self.store.get_app_services() services = self.store.get_app_services()
# Partition the users by appservice. # Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = [] missing = []
for user_id, device, algorithm in query: for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id): if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm)) missing.append((user_id, device, algorithm, count))
continue continue
# Find the associated appservice. # Find the associated appservice.
for service in services: for service in services:
if service.is_exclusive_user(user_id): if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append( query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm) (user_id, device, algorithm, count)
) )
continue continue

View File

@ -564,7 +564,7 @@ class E2eKeysHandler:
async def claim_local_one_time_keys( async def claim_local_one_time_keys(
self, self,
local_query: List[Tuple[str, str, str]], local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool, always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users. """Claim one time keys for local users.
@ -581,6 +581,12 @@ class E2eKeysHandler:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
""" """
# Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [
(user_id, device_id, algorithm, min(count, 5))
for user_id, device_id, algorithm, count in local_query
]
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S # If the application services have not provided any keys via the C-S
@ -607,7 +613,7 @@ class E2eKeysHandler:
# from the appservice for that user ID / device ID. If it is found, # from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a # check if any of the keys match the requested algorithm & are a
# fallback key. # fallback key.
for user_id, device_id, algorithm in local_query: for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query. # Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {}) as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False found_otk = False
@ -630,13 +636,17 @@ class E2eKeysHandler:
.get(device_id, {}) .get(device_id, {})
.keys() .keys()
) )
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used)) fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else: else:
# All fallback keys get marked as used. # All fallback keys get marked as used.
fallback_query = [ fallback_query = [
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True) (user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found for user_id, device_id, algorithm, count in not_found
] ]
# For each user that does not have a one-time keys available, see if # For each user that does not have a one-time keys available, see if
@ -650,18 +660,19 @@ class E2eKeysHandler:
@trace @trace
async def claim_one_time_keys( async def claim_one_time_keys(
self, self,
query: Dict[str, Dict[str, Dict[str, str]]], query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int], timeout: Optional[int],
always_include_fallback_keys: bool, always_include_fallback_keys: bool,
) -> JsonDict: ) -> JsonDict:
local_query: List[Tuple[str, str, str]] = [] local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items(): for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)): if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in one_time_keys.items(): for device_id, algorithms in one_time_keys.items():
local_query.append((user_id, device_id, algorithm)) for algorithm, count in algorithms.items():
local_query.append((user_id, device_id, algorithm, count))
else: else:
domain = get_domain_from_id(user_id) domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@ -692,7 +703,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
remote_result = await self.federation.claim_client_keys( remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout destination, device_keys, timeout=timeout
) )
for user_id, keys in remote_result["one_time_keys"].items(): for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys: if user_id in device_keys:

View File

@ -16,7 +16,8 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Any, Optional, Tuple from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# Generate a count for each algorithm, which is hard-coded to 1.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in one_time_keys.items():
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
result = await self.e2e_keys_handler.claim_one_time_keys( result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False query, timeout, always_include_fallback_keys=False
) )
return 200, result return 200, result
class UnstableOneTimeKeyServlet(RestServlet): class UnstableOneTimeKeyServlet(RestServlet):
""" """
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
fallback keys in the response. querying for multiple OTKs at once and always includes fallback keys in the
response.
POST /keys/claim HTTP/1.1
{
"one_time_keys": {
"<user_id>": {
"<device_id>": ["<algorithm>", ...]
} } }
HTTP/1.1 200 OK
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
""" """
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# Generate a count for each algorithm.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithms in one_time_keys.items():
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
result = await self.e2e_keys_handler.claim_one_time_keys( result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True query, timeout, always_include_fallback_keys=True
) )
return 200, result return 200, result

View File

@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
... ...
async def claim_e2e_one_time_keys( async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]] self, query_list: Iterable[Tuple[str, str, str, int]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: ) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Take a list of one time keys out of the database. """Take a list of one time keys out of the database.
Args: Args:
@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@trace @trace
def _claim_e2e_one_time_key_simple( def _claim_e2e_one_time_key_simple(
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str txn: LoggingTransaction,
) -> Optional[Tuple[str, str]]: user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING. """Claim OTK for device for DBs that don't support RETURNING.
Returns: Returns:
@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
sql = """ sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ? WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT 1 LIMIT ?
""" """
txn.execute(sql, (user_id, device_id, algorithm)) txn.execute(sql, (user_id, device_id, algorithm, count))
otk_row = txn.fetchone() otk_rows = list(txn)
if otk_row is None: if not otk_rows:
return None return []
key_id, key_json = otk_row self.db_pool.simple_delete_many_txn(
self.db_pool.simple_delete_one_txn(
txn, txn,
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
"algorithm": algorithm, "algorithm": algorithm,
"key_id": key_id,
}, },
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
return f"{algorithm}:{key_id}", key_json return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
@trace @trace
def _claim_e2e_one_time_key_returning( def _claim_e2e_one_time_key_returning(
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str txn: LoggingTransaction,
) -> Optional[Tuple[str, str]]: user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING. """Claim OTK for device for DBs that support RETURNING.
Returns: Returns:
@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
AND key_id IN ( AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ? WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT 1 LIMIT ?
) )
RETURNING key_id, key_json RETURNING key_id, key_json
""" """
txn.execute( txn.execute(
sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
) )
otk_row = txn.fetchone() otk_rows = list(txn)
if otk_row is None: if not otk_rows:
return None return []
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
key_id, key_json = otk_row return [
return f"{algorithm}:{key_id}", key_json (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str]] = [] missing: List[Tuple[str, str, str, int]] = []
for user_id, device_id, algorithm in query_list: for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning: if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that # If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode. # allows us to use autocommit mode.
@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False db_autocommit = False
claim_row = await self.db_pool.runInteraction( claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", "claim_e2e_one_time_keys",
_claim_e2e_one_time_key, _claim_e2e_one_time_key,
user_id, user_id,
device_id, device_id,
algorithm, algorithm,
count,
db_autocommit=db_autocommit, db_autocommit=db_autocommit,
) )
if claim_row: if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault( device_results = results.setdefault(user_id, {}).setdefault(
device_id, {} device_id, {}
) )
device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) for claim_row in claim_rows:
else: device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
missing.append((user_id, device_id, algorithm)) # Did we get enough OTKs?
count -= len(claim_rows)
if count:
missing.append((user_id, device_id, algorithm, count))
return results, missing return results, missing

View File

@ -195,11 +195,11 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
MISSING_KEYS = [ MISSING_KEYS = [
# Known user, known device, missing algorithm. # Known user, known device, missing algorithm.
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), ("@alice:example.org", "DEVICE_2", "xyz", 1),
# Known user, missing device. # Known user, missing device.
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), ("@alice:example.org", "DEVICE_3", "signed_curve25519", 1),
# Unknown user. # Unknown user.
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), ("@bob:example.org", "DEVICE_4", "signed_curve25519", 1),
] ]
claimed_keys, missing = self.get_success( claimed_keys, missing = self.get_success(
@ -207,9 +207,8 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
self.service, self.service,
[ [
# Found devices # Found devices
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), ("@alice:example.org", "DEVICE_1", "signed_curve25519", 1),
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), ("@alice:example.org", "DEVICE_2", "signed_curve25519", 1),
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
] ]
+ MISSING_KEYS, + MISSING_KEYS,
) )

View File

@ -160,7 +160,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success( res2 = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -205,7 +205,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# key # key
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -224,7 +224,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# claiming an OTK again should return the same fallback key # claiming an OTK again should return the same fallback key
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -273,7 +273,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -285,7 +285,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -306,7 +306,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -347,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# return both. # return both.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -369,7 +369,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claiming an OTK again should return only the fallback key. # Claiming an OTK again should return only the fallback key.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, {local_user: {device_id: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1052,7 +1052,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Setup a response, but only for device 2. # Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable( self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
) )
# we shouldn't have any unused fallback keys yet # we shouldn't have any unused fallback keys yet
@ -1079,11 +1079,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# query the fallback keys. # query the fallback keys.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{ {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
"one_time_keys": {
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
}
},
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -1128,7 +1124,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will ask the appservice and do nothing else. # Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}}, {local_user: {device_id_1: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1172,7 +1168,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# uploaded fallback key. # uploaded fallback key.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}}, {local_user: {device_id_1: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1205,7 +1201,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will return information only from the database. # Claim OTKs, which will return information only from the database.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}}, {local_user: {device_id_1: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1232,7 +1228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will return only the fallback key from the database. # Claim OTKs, which will return only the fallback key from the database.
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}}, {local_user: {device_id_1: {"alg1": 1}}},
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )