From b11450dedc59b117ad23426b47f2465c459ea62a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Jul 2020 08:48:58 -0400 Subject: [PATCH] Convert E2E key and room key handlers to async/await. (#7851) --- changelog.d/7851.misc | 1 + synapse/handlers/e2e_keys.py | 147 +++++------ synapse/handlers/e2e_room_keys.py | 75 +++--- tests/handlers/test_e2e_keys.py | 288 +++++++++++++-------- tests/handlers/test_e2e_room_keys.py | 373 ++++++++++++++++++--------- 5 files changed, 522 insertions(+), 362 deletions(-) create mode 100644 changelog.d/7851.misc diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc new file mode 100644 index 000000000..e5cf540ed --- /dev/null +++ b/changelog.d/7851.misc @@ -0,0 +1 @@ +Convert E2E keys and room keys handlers to async/await. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a7e60cbc2..361dd64cd 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -77,8 +77,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -124,7 +123,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -142,7 +141,7 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache(query_list) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -161,14 +160,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -192,7 +190,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -201,11 +199,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -227,7 +225,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -251,7 +249,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -267,8 +265,7 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database Args: @@ -289,7 +286,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -315,8 +312,7 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices(self, query): """Get E2E device keys for local users Args: @@ -354,7 +350,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -364,16 +360,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -382,8 +377,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -399,7 +393,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -411,12 +405,11 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -429,7 +422,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -454,9 +447,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -477,12 +469,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -494,7 +486,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -507,15 +499,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -533,7 +524,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. - existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -556,10 +547,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -574,7 +564,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -613,10 +603,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -626,23 +616,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -667,13 +656,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -681,21 +670,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -728,7 +716,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -738,12 +726,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -853,8 +841,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -882,7 +869,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -905,7 +892,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -958,8 +945,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. @@ -983,7 +969,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1009,15 +995,14 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, ): """Queries cross-signing keys for a remote user and saves them to the database @@ -1035,7 +1020,7 @@ class E2eKeysHandler(object): If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1101,14 +1086,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. Args: @@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f55470a70..0bb983dc2 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. @@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. - existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together @@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1acf287ca..cdd093ffa 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and upload both device keys. The server @@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: @@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. This user will be signed by # the first user @@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, + }, + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} - }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} - }, - }, - }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 822ea42dd..3362050ce 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}})