Convert E2E key and room key handlers to async/await. (#7851)

This commit is contained in:
Patrick Cloke 2020-07-15 08:48:58 -04:00 committed by GitHub
parent 111e70d75c
commit b11450dedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 522 additions and 362 deletions

1
changelog.d/7851.misc Normal file
View File

@ -0,0 +1 @@
Convert E2E keys and room keys handlers to async/await.

View File

@ -77,8 +77,7 @@ class E2eKeysHandler(object):
) )
@trace @trace
@defer.inlineCallbacks async def query_devices(self, query_body, timeout, from_user_id):
def query_devices(self, query_body, timeout, from_user_id):
""" Handle a device key query from a client """ Handle a device key query from a client
{ {
@ -124,7 +123,7 @@ class E2eKeysHandler(object):
failures = {} failures = {}
results = {} results = {}
if local_query: if local_query:
local_result = yield self.query_local_devices(local_query) local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items(): for user_id, keys in local_result.items():
if user_id in local_query: if user_id in local_query:
results[user_id] = keys results[user_id] = keys
@ -142,7 +141,7 @@ class E2eKeysHandler(object):
( (
user_ids_not_in_cache, user_ids_not_in_cache,
remote_results, remote_results,
) = yield self.store.get_user_devices_from_cache(query_list) ) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items(): for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {}) user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items(): for device_id, device in devices.items():
@ -161,14 +160,13 @@ class E2eKeysHandler(object):
r[user_id] = remote_queries[user_id] r[user_id] = remote_queries[user_id]
# Get cached cross-signing keys # Get cached cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache( cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id device_keys_query, from_user_id
) )
# Now fetch any devices that we don't have in our cache # Now fetch any devices that we don't have in our cache
@trace @trace
@defer.inlineCallbacks async def do_remote_query(destination):
def do_remote_query(destination):
"""This is called when we are querying the device list of a user on """This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for cache. If we share a room with this user and we're not querying for
@ -192,7 +190,7 @@ class E2eKeysHandler(object):
if device_list: if device_list:
continue continue
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids: if not room_ids:
continue continue
@ -201,11 +199,11 @@ class E2eKeysHandler(object):
# done an initial sync on the device list so we do it now. # done an initial sync on the device list so we do it now.
try: try:
if self._is_master: if self._is_master:
user_devices = yield self.device_handler.device_list_updater.user_device_resync( user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id user_id
) )
else: else:
user_devices = yield self._user_device_resync_client( user_devices = await self._user_device_resync_client(
user_id=user_id user_id=user_id
) )
@ -227,7 +225,7 @@ class E2eKeysHandler(object):
destination_query.pop(user_id) destination_query.pop(user_id)
try: try:
remote_result = yield self.federation.query_client_keys( remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout destination, {"device_keys": destination_query}, timeout=timeout
) )
@ -251,7 +249,7 @@ class E2eKeysHandler(object):
set_tag("error", True) set_tag("error", True)
set_tag("reason", failure) set_tag("reason", failure)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(do_remote_query, destination) run_in_background(do_remote_query, destination)
@ -267,8 +265,7 @@ class E2eKeysHandler(object):
return ret return ret
@defer.inlineCallbacks async def get_cross_signing_keys_from_cache(self, query, from_user_id):
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database """Get cross-signing keys for users from the database
Args: Args:
@ -289,7 +286,7 @@ class E2eKeysHandler(object):
user_ids = list(query) user_ids = list(query)
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
for user_id, user_info in keys.items(): for user_id, user_info in keys.items():
if user_info is None: if user_info is None:
@ -315,8 +312,7 @@ class E2eKeysHandler(object):
} }
@trace @trace
@defer.inlineCallbacks async def query_local_devices(self, query):
def query_local_devices(self, query):
"""Get E2E device keys for local users """Get E2E device keys for local users
Args: Args:
@ -354,7 +350,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict # make sure that each queried user appears in the result dict
result_dict[user_id] = {} result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query) results = await self.store.get_e2e_device_keys(local_query)
# Build the result structure # Build the result structure
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
@ -364,16 +360,15 @@ class E2eKeysHandler(object):
log_kv(results) log_kv(results)
return result_dict return result_dict
@defer.inlineCallbacks async def on_federation_query_client_keys(self, query_body):
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server """ Handle a device key query from a federated server
""" """
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query) res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res} ret = {"device_keys": res}
# add in the cross-signing keys # add in the cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache( cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None device_keys_query, None
) )
@ -382,8 +377,7 @@ class E2eKeysHandler(object):
return ret return ret
@trace @trace
@defer.inlineCallbacks async def claim_one_time_keys(self, query, timeout):
def claim_one_time_keys(self, query, timeout):
local_query = [] local_query = []
remote_queries = {} remote_queries = {}
@ -399,7 +393,7 @@ class E2eKeysHandler(object):
set_tag("local_key_query", local_query) set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries) set_tag("remote_key_query", remote_queries)
results = yield self.store.claim_e2e_one_time_keys(local_query) results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {} json_result = {}
failures = {} failures = {}
@ -411,12 +405,11 @@ class E2eKeysHandler(object):
} }
@trace @trace
@defer.inlineCallbacks async def claim_client_keys(destination):
def claim_client_keys(destination):
set_tag("destination", destination) set_tag("destination", destination)
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
remote_result = yield self.federation.claim_client_keys( remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout destination, {"one_time_keys": device_keys}, timeout=timeout
) )
for user_id, keys in remote_result["one_time_keys"].items(): for user_id, keys in remote_result["one_time_keys"].items():
@ -429,7 +422,7 @@ class E2eKeysHandler(object):
set_tag("error", True) set_tag("error", True)
set_tag("reason", failure) set_tag("reason", failure)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(claim_client_keys, destination) run_in_background(claim_client_keys, destination)
@ -454,9 +447,8 @@ class E2eKeysHandler(object):
log_kv({"one_time_keys": json_result, "failures": failures}) log_kv({"one_time_keys": json_result, "failures": failures})
return {"one_time_keys": json_result, "failures": failures} return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
@tag_args @tag_args
def upload_keys_for_user(self, user_id, device_id, keys): async def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -477,12 +469,12 @@ class E2eKeysHandler(object):
} }
) )
# TODO: Sign the JSON with the server key # TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys( changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys user_id, device_id, time_now, device_keys
) )
if changed: if changed:
# Only notify about device updates *if* the keys actually changed # Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id]) await self.device_handler.notify_device_update(user_id, [device_id])
else: else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
@ -494,7 +486,7 @@ class E2eKeysHandler(object):
"device_id": device_id, "device_id": device_id,
} }
) )
yield self._upload_one_time_keys_for_user( await self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys user_id, device_id, time_now, one_time_keys
) )
else: else:
@ -507,15 +499,14 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we # old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with # need to double-check the device is registered to avoid ending up with
# keys without a corresponding device. # keys without a corresponding device.
yield self.device_handler.check_device_registered(user_id, device_id) await self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result) set_tag("one_time_key_counts", result)
return {"one_time_key_counts": result} return {"one_time_key_counts": result}
@defer.inlineCallbacks async def _upload_one_time_keys_for_user(
def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys self, user_id, device_id, time_now, one_time_keys
): ):
logger.info( logger.info(
@ -533,7 +524,7 @@ class E2eKeysHandler(object):
key_list.append((algorithm, key_id, key_obj)) key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys. # First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys( existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list] user_id, device_id, [k_id for _, k_id, _ in key_list]
) )
@ -556,10 +547,9 @@ class E2eKeysHandler(object):
) )
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
@defer.inlineCallbacks async def upload_signing_keys_for_user(self, user_id, keys):
def upload_signing_keys_for_user(self, user_id, keys):
"""Upload signing keys for cross-signing """Upload signing keys for cross-signing
Args: Args:
@ -574,7 +564,7 @@ class E2eKeysHandler(object):
_check_cross_signing_key(master_key, user_id, "master") _check_cross_signing_key(master_key, user_id, "master")
else: else:
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
# if there is no master key, then we can't do anything, because all the # if there is no master key, then we can't do anything, because all the
# other cross-signing keys need to be signed by the master key # other cross-signing keys need to be signed by the master key
@ -613,10 +603,10 @@ class E2eKeysHandler(object):
# if everything checks out, then store the keys and send notifications # if everything checks out, then store the keys and send notifications
deviceids = [] deviceids = []
if "master_key" in keys: if "master_key" in keys:
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
deviceids.append(master_verify_key.version) deviceids.append(master_verify_key.version)
if "self_signing_key" in keys: if "self_signing_key" in keys:
yield self.store.set_e2e_cross_signing_key( await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key user_id, "self_signing", self_signing_key
) )
try: try:
@ -626,23 +616,22 @@ class E2eKeysHandler(object):
except ValueError: except ValueError:
raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
if "user_signing_key" in keys: if "user_signing_key" in keys:
yield self.store.set_e2e_cross_signing_key( await self.store.set_e2e_cross_signing_key(
user_id, "user_signing", user_signing_key user_id, "user_signing", user_signing_key
) )
# the signature stream matches the semantics that we want for # the signature stream matches the semantics that we want for
# user-signing key updates: only the user themselves is notified of # user-signing key updates: only the user themselves is notified of
# their own user-signing key updates # their own user-signing key updates
yield self.device_handler.notify_user_signature_update(user_id, [user_id]) await self.device_handler.notify_user_signature_update(user_id, [user_id])
# master key and self-signing key updates match the semantics of device # master key and self-signing key updates match the semantics of device
# list updates: all users who share an encrypted room are notified # list updates: all users who share an encrypted room are notified
if len(deviceids): if len(deviceids):
yield self.device_handler.notify_device_update(user_id, deviceids) await self.device_handler.notify_device_update(user_id, deviceids)
return {} return {}
@defer.inlineCallbacks async def upload_signatures_for_device_keys(self, user_id, signatures):
def upload_signatures_for_device_keys(self, user_id, signatures):
"""Upload device signatures for cross-signing """Upload device signatures for cross-signing
Args: Args:
@ -667,13 +656,13 @@ class E2eKeysHandler(object):
self_signatures = signatures.get(user_id, {}) self_signatures = signatures.get(user_id, {})
other_signatures = {k: v for k, v in signatures.items() if k != user_id} other_signatures = {k: v for k, v in signatures.items() if k != user_id}
self_signature_list, self_failures = yield self._process_self_signatures( self_signature_list, self_failures = await self._process_self_signatures(
user_id, self_signatures user_id, self_signatures
) )
signature_list.extend(self_signature_list) signature_list.extend(self_signature_list)
failures.update(self_failures) failures.update(self_failures)
other_signature_list, other_failures = yield self._process_other_signatures( other_signature_list, other_failures = await self._process_other_signatures(
user_id, other_signatures user_id, other_signatures
) )
signature_list.extend(other_signature_list) signature_list.extend(other_signature_list)
@ -681,21 +670,20 @@ class E2eKeysHandler(object):
# store the signature, and send the appropriate notifications for sync # store the signature, and send the appropriate notifications for sync
logger.debug("upload signature failures: %r", failures) logger.debug("upload signature failures: %r", failures)
yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
self_device_ids = [item.target_device_id for item in self_signature_list] self_device_ids = [item.target_device_id for item in self_signature_list]
if self_device_ids: if self_device_ids:
yield self.device_handler.notify_device_update(user_id, self_device_ids) await self.device_handler.notify_device_update(user_id, self_device_ids)
signed_users = [item.target_user_id for item in other_signature_list] signed_users = [item.target_user_id for item in other_signature_list]
if signed_users: if signed_users:
yield self.device_handler.notify_user_signature_update( await self.device_handler.notify_user_signature_update(
user_id, signed_users user_id, signed_users
) )
return {"failures": failures} return {"failures": failures}
@defer.inlineCallbacks async def _process_self_signatures(self, user_id, signatures):
def _process_self_signatures(self, user_id, signatures):
"""Process uploaded signatures of the user's own keys. """Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms: Signatures of the user's own keys from this API come in two forms:
@ -728,7 +716,7 @@ class E2eKeysHandler(object):
_, _,
self_signing_key_id, self_signing_key_id,
self_signing_verify_key, self_signing_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it. # get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so # We need to fetch it here so that we know what its key ID is, so
@ -738,12 +726,12 @@ class E2eKeysHandler(object):
master_key, master_key,
_, _,
master_verify_key, master_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") ) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify # fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what # signatures on the master key, and 2. to compare with what
# was sent if the device was signed # was sent if the device was signed
devices = yield self.store.get_e2e_device_keys([(user_id, None)]) devices = await self.store.get_e2e_device_keys([(user_id, None)])
if user_id not in devices: if user_id not in devices:
raise NotFoundError("No device keys found") raise NotFoundError("No device keys found")
@ -853,8 +841,7 @@ class E2eKeysHandler(object):
return master_key_signature_list return master_key_signature_list
@defer.inlineCallbacks async def _process_other_signatures(self, user_id, signatures):
def _process_other_signatures(self, user_id, signatures):
"""Process uploaded signatures of other users' keys. These will be the """Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing target user's master keys, signed by the uploading user's user-signing
key. key.
@ -882,7 +869,7 @@ class E2eKeysHandler(object):
user_signing_key, user_signing_key,
user_signing_key_id, user_signing_key_id,
user_signing_verify_key, user_signing_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e: except SynapseError as e:
failure = _exception_to_failure(e) failure = _exception_to_failure(e)
for user, devicemap in signatures.items(): for user, devicemap in signatures.items():
@ -905,7 +892,7 @@ class E2eKeysHandler(object):
master_key, master_key,
master_key_id, master_key_id,
_, _,
) = yield self._get_e2e_cross_signing_verify_key( ) = await self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id target_user, "master", user_id
) )
@ -958,8 +945,7 @@ class E2eKeysHandler(object):
return signature_list, failures return signature_list, failures
@defer.inlineCallbacks async def _get_e2e_cross_signing_verify_key(
def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None self, user_id: str, key_type: str, from_user_id: str = None
): ):
"""Fetch locally or remotely query for a cross-signing public key. """Fetch locally or remotely query for a cross-signing public key.
@ -983,7 +969,7 @@ class E2eKeysHandler(object):
SynapseError: if `user_id` is invalid SynapseError: if `user_id` is invalid
""" """
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
key = yield self.store.get_e2e_cross_signing_key( key = await self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id user_id, key_type, from_user_id
) )
@ -1009,15 +995,14 @@ class E2eKeysHandler(object):
key, key,
key_id, key_id,
verify_key, verify_key,
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None: if key is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id)) raise NotFoundError("No %s key found for %s" % (key_type, user_id))
return key, key_id, verify_key return key, key_id, verify_key
@defer.inlineCallbacks async def _retrieve_cross_signing_keys_for_remote_user(
def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str, self, user: UserID, desired_key_type: str,
): ):
"""Queries cross-signing keys for a remote user and saves them to the database """Queries cross-signing keys for a remote user and saves them to the database
@ -1035,7 +1020,7 @@ class E2eKeysHandler(object):
If the key cannot be retrieved, all values in the tuple will instead be None. If the key cannot be retrieved, all values in the tuple will instead be None.
""" """
try: try:
remote_result = yield self.federation.query_user_devices( remote_result = await self.federation.query_user_devices(
user.domain, user.to_string() user.domain, user.to_string()
) )
except Exception as e: except Exception as e:
@ -1101,14 +1086,14 @@ class E2eKeysHandler(object):
desired_key_id = key_id desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries # At the same time, store this key in the db for subsequent queries
yield self.store.set_e2e_cross_signing_key( await self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content user.to_string(), key_type, key_content
) )
# Notify clients that new devices for this user have been discovered # Notify clients that new devices for this user have been discovered
if retrieved_device_ids: if retrieved_device_ids:
# XXX is this necessary? # XXX is this necessary?
yield self.device_handler.notify_device_update( await self.device_handler.notify_device_update(
user.to_string(), retrieved_device_ids user.to_string(), retrieved_device_ids
) )
@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object):
iterable=True, iterable=True,
) )
@defer.inlineCallbacks async def incoming_signing_key_update(self, origin, edu_content):
def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for """Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list. parsing the EDU and adding to pending updates list.
@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object):
logger.warning("Got signing key update edu for %r from %r", user_id, origin) logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return return
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids: if not room_ids:
# We don't share any rooms with this user. Ignore update, as we # We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates. # probably won't get any further updates.
@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object):
(master_key, self_signing_key) (master_key, self_signing_key)
) )
yield self._handle_signing_key_updates(user_id) await self._handle_signing_key_updates(user_id)
@defer.inlineCallbacks async def _handle_signing_key_updates(self, user_id):
def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates. """Actually handle pending updates.
Args: Args:
@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object):
device_handler = self.e2e_keys_handler.device_handler device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater device_list_updater = device_handler.device_list_updater
with (yield self._remote_edu_linearizer.queue(user_id)): with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, []) pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates: if not pending_updates:
# This can happen since we batch updates # This can happen since we batch updates
@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates) logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates: for master_key, self_signing_key in pending_updates:
new_device_ids = yield device_list_updater.process_cross_signing_key_update( new_device_ids = await device_list_updater.process_cross_signing_key_update(
user_id, master_key, self_signing_key, user_id, master_key, self_signing_key,
) )
device_ids = device_ids + new_device_ids device_ids = device_ids + new_device_ids
yield device_handler.notify_device_update(user_id, device_ids) await device_handler.notify_device_update(user_id, device_ids)

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
NotFoundError, NotFoundError,
@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
self._upload_linearizer = Linearizer("upload_room_keys_lock") self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace @trace
@defer.inlineCallbacks async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given """Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session. room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details. See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
# we deliberately take the lock to get keys so that changing the version # we deliberately take the lock to get keys so that changing the version
# works atomically # works atomically
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists # make sure the backup version exists
try: try:
yield self.store.get_e2e_room_keys_version_info(user_id, version) await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown backup version") raise NotFoundError("Unknown backup version")
else: else:
raise raise
results = yield self.store.get_e2e_room_keys( results = await self.store.get_e2e_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
return results return results
@trace @trace
@defer.inlineCallbacks async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session. room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
""" """
# lock for consistency with uploading # lock for consistency with uploading
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists # make sure the backup version exists
try: try:
version_info = yield self.store.get_e2e_room_keys_version_info( version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version user_id, version
) )
except StoreError as e: except StoreError as e:
@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
else: else:
raise raise
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
version_etag = version_info["etag"] + 1 version_etag = version_info["etag"] + 1
yield self.store.update_e2e_room_keys_version( await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag user_id, version, None, version_etag
) )
count = yield self.store.count_e2e_room_keys(user_id, version) count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count} return {"etag": str(version_etag), "count": count}
@trace @trace
@defer.inlineCallbacks async def upload_room_keys(self, user_id, version, room_keys):
def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting """Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT(). into the current backup as described in RoomKeysServlet.on_PUT().
@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here? # XXX: perhaps we should use a finer grained lock here?
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
# Check that the version we're trying to upload is the current version # Check that the version we're trying to upload is the current version
try: try:
version_info = yield self.store.get_e2e_room_keys_version_info(user_id) version_info = await self.store.get_e2e_room_keys_version_info(user_id)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,)) raise NotFoundError("Version '%s' not found" % (version,))
@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
if version_info["version"] != version: if version_info["version"] != version:
# Check that the version we're trying to upload actually exists # Check that the version we're trying to upload actually exists
try: try:
version_info = yield self.store.get_e2e_room_keys_version_info( version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version user_id, version
) )
# if we get this far, the version must exist # if we get this far, the version must exist
@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
# submitted. Then compare them with the submitted keys. If the # submitted. Then compare them with the submitted keys. If the
# key is new, insert it; if the key should be updated, then update # key is new, insert it; if the key should be updated, then update
# it; otherwise, drop it. # it; otherwise, drop it.
existing_keys = yield self.store.get_e2e_room_keys_multi( existing_keys = await self.store.get_e2e_room_keys_multi(
user_id, version, room_keys["rooms"] user_id, version, room_keys["rooms"]
) )
to_insert = [] # batch the inserts together to_insert = [] # batch the inserts together
@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
# updates are done one at a time in the DB, so send # updates are done one at a time in the DB, so send
# updates right away rather than batching them up, # updates right away rather than batching them up,
# like we do with the inserts # like we do with the inserts
yield self.store.update_e2e_room_key( await self.store.update_e2e_room_key(
user_id, version, room_id, session_id, room_key user_id, version, room_id, session_id, room_key
) )
changed = True changed = True
@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
changed = True changed = True
if len(to_insert): if len(to_insert):
yield self.store.add_e2e_room_keys(user_id, version, to_insert) await self.store.add_e2e_room_keys(user_id, version, to_insert)
version_etag = version_info["etag"] version_etag = version_info["etag"]
if changed: if changed:
version_etag = version_etag + 1 version_etag = version_etag + 1
yield self.store.update_e2e_room_keys_version( await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag user_id, version, None, version_etag
) )
count = yield self.store.count_e2e_room_keys(user_id, version) count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count} return {"etag": str(version_etag), "count": count}
@staticmethod @staticmethod
@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
return True return True
@trace @trace
@defer.inlineCallbacks async def create_version(self, user_id, version_info):
def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new """Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be backup version for the user's keys; previous backups will no longer be
writeable to. writeable to.
@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version # lock everyone out until we've switched version
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
new_version = yield self.store.create_e2e_room_keys_version( new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info user_id, version_info
) )
return new_version return new_version
@defer.inlineCallbacks async def get_version_info(self, user_id, version=None):
def get_version_info(self, user_id, version=None):
"""Get the info about a given version of the user's backup """Get the info about a given version of the user's backup
Args: Args:
@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
} }
""" """
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
try: try:
res = yield self.store.get_e2e_room_keys_version_info(user_id, version) res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown backup version") raise NotFoundError("Unknown backup version")
else: else:
raise raise
res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
res["etag"] = str(res["etag"]) res["etag"] = str(res["etag"])
return res return res
@trace @trace
@defer.inlineCallbacks async def delete_version(self, user_id, version=None):
def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup """Deletes a given version of the user's e2e_room_keys backup
Args: Args:
@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
NotFoundError: if this backup version doesn't exist NotFoundError: if this backup version doesn't exist
""" """
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
try: try:
yield self.store.delete_e2e_room_keys_version(user_id, version) await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown backup version") raise NotFoundError("Unknown backup version")
@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
raise raise
@trace @trace
@defer.inlineCallbacks async def update_version(self, user_id, version, version_info):
def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup """Update the info about a given version of the user's backup
Args: Args:
@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
raise SynapseError( raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM 400, "Version in body does not match", Codes.INVALID_PARAM
) )
with (yield self._upload_linearizer.queue(user_id)): with (await self._upload_linearizer.queue(user_id)):
try: try:
old_info = yield self.store.get_e2e_room_keys_version_info( old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version user_id, version
) )
except StoreError as e: except StoreError as e:
@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
if old_info["algorithm"] != version_info["algorithm"]: if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
yield self.store.update_e2e_room_keys_version( await self.store.update_e2e_room_keys_version(
user_id, version, version_info user_id, version, version_info
) )

View File

@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list. """If the user has no devices, we expect an empty list.
""" """
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None}) res = yield defer.ensureDeferred(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks @defer.inlineCallbacks
@ -60,16 +62,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
} }
res = yield self.handler.upload_keys_for_user( res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem # we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2" keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user( res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@defer.inlineCallbacks @defer.inlineCallbacks
@ -84,37 +90,48 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
} }
res = yield self.handler.upload_keys_for_user( res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try: try:
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
) )
)
self.fail("No error when changing string key") self.fail("No error when changing string key")
except errors.SynapseError: except errors.SynapseError:
pass pass
try: try:
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
) )
)
self.fail("No error when replacing dict key with string") self.fail("No error when replacing dict key with string")
except errors.SynapseError: except errors.SynapseError:
pass pass
try: try:
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} self.handler.upload_keys_for_user(
local_user,
device_id,
{"one_time_keys": {"alg1:k1": {"key": "key"}}},
)
) )
self.fail("No error when replacing string key with dict") self.fail("No error when replacing string key with dict")
except errors.SynapseError: except errors.SynapseError:
pass pass
try: try:
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, local_user,
device_id, device_id,
{ {
@ -123,6 +140,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
} }
}, },
) )
)
self.fail("No error when replacing dict key") self.fail("No error when replacing dict key")
except errors.SynapseError: except errors.SynapseError:
pass pass
@ -133,14 +151,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"} keys = {"alg1:k1": "key1"}
res = yield self.handler.upload_keys_for_user( res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
res2 = yield self.handler.claim_one_time_keys( res2 = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
)
self.assertEqual( self.assertEqual(
res2, res2,
{ {
@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield self.handler.upload_signing_keys_for_user(local_user, keys1) yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
keys2 = { keys2 = {
"master_key": { "master_key": {
@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield self.handler.upload_signing_keys_for_user(local_user, keys2) yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys2)
)
devices = yield self.handler.query_devices( devices = yield defer.ensureDeferred(
{"device_keys": {local_user: []}}, 0, local_user self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
) )
yield self.handler.upload_signing_keys_for_user(local_user, keys1) yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
# upload two device keys, which will be signed later by the self-signing key # upload two device keys, which will be signed later by the self-signing key
device_key_1 = { device_key_1 = {
@ -245,33 +273,41 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}}, "signatures": {local_user: {"ed25519:def": "base64+signature"}},
} }
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1} local_user, "abc", {"device_keys": device_key_1}
) )
yield self.handler.upload_keys_for_user( )
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2} local_user, "def", {"device_keys": device_key_2}
) )
)
# sign the first device key and upload it # sign the first device key and upload it
del device_key_1["signatures"] del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key) sign.sign_json(device_key_1, local_user, signing_key)
yield self.handler.upload_signatures_for_device_keys( yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}} local_user, {local_user: {"abc": device_key_1}}
) )
)
# sign the second device key and upload both device keys. The server # sign the second device key and upload both device keys. The server
# should ignore the first device key since it already has a valid # should ignore the first device key since it already has a valid
# signature for it # signature for it
del device_key_2["signatures"] del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key) sign.sign_json(device_key_2, local_user, signing_key)
yield self.handler.upload_signatures_for_device_keys( yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
) )
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = yield self.handler.query_devices( devices = yield defer.ensureDeferred(
{"device_keys": {local_user: []}}, 0, local_user self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
) )
del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"]
@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield self.handler.upload_signing_keys_for_user(local_user, keys1) yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
res = None res = None
try: try:
@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = e.code res = e.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
res = yield self.handler.query_local_devices({local_user: None}) res = yield defer.ensureDeferred(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks @defer.inlineCallbacks
@ -331,9 +371,11 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
) )
yield self.handler.upload_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key} local_user, device_id, {"device_keys": device_key}
) )
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key, "user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key, "self_signing_key": selfsigning_key,
} }
yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
)
# set up another user with a master key. This user will be signed by # set up another user with a master key. This user will be signed by
# the first user # the first user
@ -384,12 +428,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"], "usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
} }
yield self.handler.upload_signing_keys_for_user( yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key} other_user, {"master_key": other_master_key}
) )
)
# test various signature failures (see below) # test various signature failures (see below)
ret = yield self.handler.upload_signatures_for_device_keys( ret = yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, local_user,
{ {
local_user: { local_user: {
@ -408,7 +455,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519:xyz": device_pubkey, "ed25519:xyz": device_pubkey,
}, },
"signatures": { "signatures": {
local_user: {"ed25519:" + selfsigning_pubkey: "something"} local_user: {
"ed25519:" + selfsigning_pubkey: "something"
}
}, },
}, },
# fails because device is unknown # fails because device is unknown
@ -417,7 +466,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_id": local_user, "user_id": local_user,
"device_id": "unknown", "device_id": "unknown",
"signatures": { "signatures": {
local_user: {"ed25519:" + selfsigning_pubkey: "something"} local_user: {
"ed25519:" + selfsigning_pubkey: "something"
}
}, },
}, },
# fails because the signature is invalid # fails because the signature is invalid
@ -438,7 +489,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_id": other_user, "user_id": other_user,
"device_id": "unknown", "device_id": "unknown",
"signatures": { "signatures": {
local_user: {"ed25519:" + usersigning_pubkey: "something"} local_user: {
"ed25519:" + usersigning_pubkey: "something"
}
}, },
}, },
other_master_pubkey: { other_master_pubkey: {
@ -446,15 +499,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# should fail with UNKNOWN # should fail with UNKNOWN
"user_id": other_user, "user_id": other_user,
"usage": ["master"], "usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, "keys": {
"ed25519:" + other_master_pubkey: other_master_pubkey
},
"something": "random", "something": "random",
"signatures": { "signatures": {
local_user: {"ed25519:" + usersigning_pubkey: "something"} local_user: {
"ed25519:" + usersigning_pubkey: "something"
}
}, },
}, },
}, },
}, },
) )
)
user_failures = ret["failures"][local_user] user_failures = ret["failures"][local_user]
self.assertEqual( self.assertEqual(
@ -478,20 +536,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key)
ret = yield self.handler.upload_signatures_for_device_keys( ret = yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, local_user,
{ {
local_user: {device_id: device_key, master_pubkey: master_key}, local_user: {device_id: device_key, master_pubkey: master_key},
other_user: {other_master_pubkey: other_master_key}, other_user: {other_master_pubkey: other_master_key},
}, },
) )
)
self.assertEqual(ret["failures"], {}) self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there # fetch the signed keys/devices and make sure that the signatures are there
ret = yield self.handler.query_devices( ret = yield defer.ensureDeferred(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user {"device_keys": {local_user: [], other_user: []}}, 0, local_user
) )
)
self.assertEqual( self.assertEqual(
ret["device_keys"][local_user]["xyz"]["signatures"][local_user][ ret["device_keys"][local_user]["xyz"]["signatures"][local_user][

View File

@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.get_version_info(self.local_user) yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.get_version_info(self.local_user, "bogus_version") yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "bogus_version")
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self): def test_create_version(self):
"""Check that we can create and then retrieve versions. """Check that we can create and then retrieve versions.
""" """
res = yield self.handler.create_version( res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(res, "1") self.assertEqual(res, "1")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"] version_etag = res["etag"]
self.assertIsInstance(version_etag, str) self.assertIsInstance(version_etag, str)
del res["etag"] del res["etag"]
@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
# check we can retrieve it as a specific version # check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1") res = yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "1")
)
self.assertEqual(res["etag"], version_etag) self.assertEqual(res["etag"], version_etag)
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
# upload a new one... # upload a new one...
res = yield self.handler.create_version( res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{ {
"algorithm": "m.megolm_backup.v1", "algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data", "auth_data": "second_version_auth_data",
}, },
) )
)
self.assertEqual(res, "2") self.assertEqual(res, "2")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -149,13 +160,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self): def test_update_version(self):
"""Check that we can update versions. """Check that we can update versions.
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = yield self.handler.update_version( res = yield defer.ensureDeferred(
self.handler.update_version(
self.local_user, self.local_user,
version, version,
{ {
@ -164,10 +181,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"version": version, "version": version,
}, },
) )
)
self.assertDictEqual(res, {}) self.assertDictEqual(res, {})
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -185,7 +203,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.update_version( yield defer.ensureDeferred(
self.handler.update_version(
self.local_user, self.local_user,
"1", "1",
{ {
@ -194,6 +213,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"version": "1", "version": "1",
}, },
) )
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -202,13 +222,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self): def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body """Check that the update succeeds if the version is missing from the body
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield self.handler.update_version( yield defer.ensureDeferred(
self.handler.update_version(
self.local_user, self.local_user,
version, version,
{ {
@ -216,9 +242,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"auth_data": "revised_first_version_auth_data", "auth_data": "revised_first_version_auth_data",
}, },
) )
)
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -234,15 +261,21 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self): def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match """Check that we get a 400 if the version in the body doesn't match
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = None res = None
try: try:
yield self.handler.update_version( yield defer.ensureDeferred(
self.handler.update_version(
self.local_user, self.local_user,
version, version,
{ {
@ -251,6 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"version": "incorrect", "version": "incorrect",
}, },
) )
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.delete_version(self.local_user, "1") yield defer.ensureDeferred(
self.handler.delete_version(self.local_user, "1")
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.delete_version(self.local_user) yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self): def test_delete_version(self):
"""Check that we can create and then delete versions. """Check that we can create and then delete versions.
""" """
res = yield self.handler.create_version( res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(res, "1") self.assertEqual(res, "1")
# check we can delete it # check we can delete it
yield self.handler.delete_version(self.local_user, "1") yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone # check that it's gone
res = None res = None
try: try:
yield self.handler.get_version_info(self.local_user, "1") yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "1")
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.get_room_keys(self.local_user, "bogus_version") yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, "bogus_version")
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self): def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup """Check we get an empty response from an empty backup
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = yield self.handler.get_room_keys(self.local_user, version) res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys, # TODO: test the locking semantics when uploading room_keys,
@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
""" """
res = None res = None
try: try:
yield self.handler.upload_room_keys( yield defer.ensureDeferred(
self.local_user, "no_version", room_keys self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
) )
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
@ -343,17 +395,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version """Check that we get a 404 on uploading keys when an nonexistent version
is specified is specified
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = None res = None
try: try:
yield self.handler.upload_room_keys( yield defer.ensureDeferred(
self.handler.upload_room_keys(
self.local_user, "bogus_version", room_keys self.local_user, "bogus_version", room_keys
) )
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self): def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version """Check that we get a 403 on uploading keys for an old version
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{ {
"algorithm": "m.megolm_backup.v1", "algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data", "auth_data": "second_version_auth_data",
}, },
) )
)
self.assertEqual(version, "2") self.assertEqual(version, "2")
res = None res = None
try: try:
yield self.handler.upload_room_keys(self.local_user, "1", room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, "1", room_keys)
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 403) self.assertEqual(res, 403)
@ -388,43 +456,63 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self): def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session """Check that we can insert and retrieve keys for a session
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version) res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room # check getting room_keys for a given room
res = yield self.handler.get_room_keys( res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org" self.local_user, version, room_id="!abc:matrix.org"
) )
)
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id # check getting room_keys for a given session_id
res = yield self.handler.get_room_keys( res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
)
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_upload_room_keys_merge(self): def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and """Check that we can upload a new room_key for an existing session and
have it correctly merged""" have it correctly merged"""
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
# get the etag to compare to future versions # get the etag to compare to future versions
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"] backup_etag = res["etag"]
self.assertEqual(res["count"], 1) self.assertEqual(res["count"], 1)
@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session # test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2 new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new" new_room_key["session_data"] = "new"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version) res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK", "SSBBTSBBIEZJU0gK",
) )
# the etag should be the same since the session did not change # the etag should be the same since the session did not change
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag) self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it # test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version) res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
) )
# the etag should NOT be equal now, since the key changed # the etag should NOT be equal now, since the key changed
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag) self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"] backup_etag = res["etag"]
@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count # with a lower forwarding count
new_room_key["forwarded_count"] = 2 new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other" new_room_key["session_data"] = "other"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version) res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
) )
# the etag should be the same since the session did not change # the etag should be the same since the session did not change
res = yield self.handler.get_version_info(self.local_user) res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag) self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here # TODO: check edge cases as well as the common variations here
@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self): def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session """Check that we can insert and delete keys for a session
""" """
version = yield self.handler.create_version( version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user, self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, {
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
# check for bulk-delete # check for bulk-delete
yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield defer.ensureDeferred(
yield self.handler.delete_room_keys(self.local_user, version) self.handler.upload_room_keys(self.local_user, version, room_keys)
res = yield self.handler.get_room_keys( )
yield defer.ensureDeferred(
self.handler.delete_room_keys(self.local_user, version)
)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
)
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room # check for bulk-delete per room
yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield defer.ensureDeferred(
yield self.handler.delete_room_keys( self.handler.upload_room_keys(self.local_user, version, room_keys)
)
yield defer.ensureDeferred(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org" self.local_user, version, room_id="!abc:matrix.org"
) )
res = yield self.handler.get_room_keys( )
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
)
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session # check for bulk-delete per session
yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield defer.ensureDeferred(
yield self.handler.delete_room_keys( self.handler.upload_room_keys(self.local_user, version, room_keys)
)
yield defer.ensureDeferred(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
res = yield self.handler.get_room_keys( )
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
)
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})