diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 873c9b40f..aa74d4d0c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter from synapse.util import unwrapFirstError +from synapse.util.async import ObservableDeferred + from OpenSSL import crypto from collections import namedtuple @@ -88,6 +90,8 @@ class Keyring(object): "Not signed with a supported algorithm", Codes.UNAUTHORIZED, )) + else: + deferreds[group_id] = defer.Deferred() group = KeyGroup(server_name, group_id, key_ids) @@ -133,10 +137,41 @@ class Keyring(object): Codes.UNAUTHORIZED, ) - deferreds.update(self.get_server_verify_keys( - group_id_to_group - )) + server_to_deferred = { + server_name: defer.Deferred() + for server_name, _ in server_and_json + } + # We want to wait for any previous lookups to complete before + # proceeding. + wait_on_deferred = self.wait_for_previous_lookups( + [server_name for server_name, _ in server_and_json], + server_to_deferred, + ) + + # Actually start fetching keys. + wait_on_deferred.addBoth( + lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + ) + + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. + server_to_gids = {} + + def remove_deferreds(res, server_name, group_id): + server_to_gids[server_name].discard(group_id) + if not server_to_gids[server_name]: + server_to_deferred.pop(server_name).callback(None) + return res + + for g_id, deferred in deferreds.items(): + server_name = group_id_to_group[g_id].server_name + server_to_gids.setdefault(server_name, set()).add(g_id) + deferred.addBoth(remove_deferreds, server_name, g_id) + + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified return [ handle_key_deferred( group_id_to_group[g_id], @@ -145,7 +180,30 @@ class Keyring(object): for g_id in group_ids ] - def get_server_verify_keys(self, group_id_to_group): + @defer.inlineCallbacks + def wait_for_previous_lookups(self, server_names, server_to_deferred): + """Waits for any previous key lookups for the given servers to finish. + + Args: + server_names (list): list of server_names we want to lookup + server_to_deferred (dict): server_name to deferred which gets + resolved once we've finished looking up keys for that server + """ + while True: + wait_on = [ + self.key_downloads[server_name] + for server_name in server_names + if server_name in self.key_downloads + ] + if wait_on: + yield defer.DeferredList(wait_on) + else: + break + + for server_name, deferred in server_to_deferred: + self.key_downloads[server_name] = ObservableDeferred(deferred) + + def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): """Takes a dict of KeyGroups and tries to find at least one key for each group. """ @@ -157,11 +215,6 @@ class Keyring(object): self.get_keys_from_server, # Then try directly ) - group_deferreds = { - group_id: defer.Deferred() - for group_id in group_id_to_group - } - @defer.inlineCallbacks def do_iterations(): merged_results = {} @@ -182,7 +235,7 @@ class Keyring(object): for group in group_id_to_group.values(): for key_id in group.key_ids: if key_id in merged_results[group.server_name]: - group_deferreds.pop(group.group_id).callback(( + group_id_to_deferred[group.group_id].callback(( group.group_id, group.server_name, key_id, @@ -205,7 +258,7 @@ class Keyring(object): } for group in missing_groups.values(): - group_deferreds.pop(group.group_id).errback(SynapseError( + group_id_to_deferred[group.group_id].errback(SynapseError( 401, "No key for %s with id %s" % ( group.server_name, group.key_ids, @@ -214,13 +267,13 @@ class Keyring(object): )) def on_err(err): - for deferred in group_deferreds.values(): - deferred.errback(err) - group_deferreds.clear() + for deferred in group_id_to_deferred.values(): + if not deferred.called: + deferred.errback(err) do_iterations().addErrback(on_err) - return group_deferreds + return group_id_to_deferred @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids):