mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Clean up verify_json_objects_for_server
This commit is contained in:
parent
c63b1697f4
commit
fe1b369946
@ -44,7 +44,21 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||||
|
"server_name", "key_ids", "json_object", "deferred"
|
||||||
|
))
|
||||||
|
"""
|
||||||
|
A request for a verify key to verify a JSON object.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name(str): The name of the server to verify against.
|
||||||
|
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||||
|
JSON object
|
||||||
|
json_object(dict): The JSON object to verify.
|
||||||
|
deferred(twisted.internet.defer.Deferred):
|
||||||
|
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||||
|
a verify key has been fetched
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
@ -74,39 +88,32 @@ class Keyring(object):
|
|||||||
list of deferreds indicating success or failure to verify each
|
list of deferreds indicating success or failure to verify each
|
||||||
json object's signature for the given server_name.
|
json object's signature for the given server_name.
|
||||||
"""
|
"""
|
||||||
group_id_to_json = {}
|
verify_requests = []
|
||||||
group_id_to_group = {}
|
|
||||||
group_ids = []
|
|
||||||
|
|
||||||
next_group_id = 0
|
|
||||||
deferreds = {}
|
|
||||||
|
|
||||||
for server_name, json_object in server_and_json:
|
for server_name, json_object in server_and_json:
|
||||||
logger.debug("Verifying for %s", server_name)
|
logger.debug("Verifying for %s", server_name)
|
||||||
group_id = next_group_id
|
|
||||||
next_group_id += 1
|
|
||||||
group_ids.append(group_id)
|
|
||||||
|
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
deferreds[group_id] = defer.fail(SynapseError(
|
deferred = defer.fail(SynapseError(
|
||||||
400,
|
400,
|
||||||
"Not signed with a supported algorithm",
|
"Not signed with a supported algorithm",
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
deferreds[group_id] = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
|
|
||||||
group = KeyGroup(server_name, group_id, key_ids)
|
verify_request = VerifyKeyRequest(
|
||||||
|
server_name, key_ids, json_object, deferred
|
||||||
|
)
|
||||||
|
|
||||||
group_id_to_group[group_id] = group
|
verify_requests.append(verify_request)
|
||||||
group_id_to_json[group_id] = json_object
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_key_deferred(group, deferred):
|
def handle_key_deferred(verify_request):
|
||||||
server_name = group.server_name
|
server_name = verify_request.server_name
|
||||||
try:
|
try:
|
||||||
_, _, key_id, verify_key = yield deferred
|
_, key_id, verify_key = yield verify_request.deferred
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Got IOError when downloading keys for %s: %s %s",
|
"Got IOError when downloading keys for %s: %s %s",
|
||||||
@ -128,7 +135,7 @@ class Keyring(object):
|
|||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
json_object = group_id_to_json[group.group_id]
|
json_object = verify_request.json_object
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
@ -157,36 +164,34 @@ class Keyring(object):
|
|||||||
|
|
||||||
# Actually start fetching keys.
|
# Actually start fetching keys.
|
||||||
wait_on_deferred.addBoth(
|
wait_on_deferred.addBoth(
|
||||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
lambda _: self.get_server_verify_keys(verify_requests)
|
||||||
)
|
)
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
# any lookups waiting will proceed.
|
# any lookups waiting will proceed.
|
||||||
server_to_gids = {}
|
server_to_request_ids = {}
|
||||||
|
|
||||||
def remove_deferreds(res, server_name, group_id):
|
def remove_deferreds(res, server_name, verify_request):
|
||||||
server_to_gids[server_name].discard(group_id)
|
request_id = id(verify_request)
|
||||||
if not server_to_gids[server_name]:
|
server_to_request_ids[server_name].discard(request_id)
|
||||||
|
if not server_to_request_ids[server_name]:
|
||||||
d = server_to_deferred.pop(server_name, None)
|
d = server_to_deferred.pop(server_name, None)
|
||||||
if d:
|
if d:
|
||||||
d.callback(None)
|
d.callback(None)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
for g_id, deferred in deferreds.items():
|
for verify_request in verify_requests:
|
||||||
server_name = group_id_to_group[g_id].server_name
|
server_name = verify_request.server_name
|
||||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
request_id = id(verify_request)
|
||||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||||
|
|
||||||
# Pass those keys to handle_key_deferred so that the json object
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
# signatures can be verified
|
# signatures can be verified
|
||||||
return [
|
return [
|
||||||
preserve_context_over_fn(
|
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||||
handle_key_deferred,
|
for verify_request in verify_requests
|
||||||
group_id_to_group[g_id],
|
|
||||||
deferreds[g_id],
|
|
||||||
)
|
|
||||||
for g_id in group_ids
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -220,7 +225,7 @@ class Keyring(object):
|
|||||||
|
|
||||||
d.addBoth(rm, server_name)
|
d.addBoth(rm, server_name)
|
||||||
|
|
||||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
def get_server_verify_keys(self, verify_requests):
|
||||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
each group.
|
each group.
|
||||||
"""
|
"""
|
||||||
@ -237,62 +242,64 @@ class Keyring(object):
|
|||||||
merged_results = {}
|
merged_results = {}
|
||||||
|
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
for verify_request in verify_requests:
|
||||||
missing_keys.setdefault(group.server_name, set()).update(
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
group.key_ids
|
verify_request.key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for fn in key_fetch_fns:
|
for fn in key_fetch_fns:
|
||||||
results = yield fn(missing_keys.items())
|
results = yield fn(missing_keys.items())
|
||||||
merged_results.update(results)
|
merged_results.update(results)
|
||||||
|
|
||||||
# We now need to figure out which groups we have keys for
|
# We now need to figure out which verify requests we have keys
|
||||||
# and which we don't
|
# for and which we don't
|
||||||
missing_groups = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
requests_missing_keys = []
|
||||||
for key_id in group.key_ids:
|
for verify_request in verify_requests:
|
||||||
if key_id in merged_results[group.server_name]:
|
server_name = verify_request.server_name
|
||||||
|
result_keys = merged_results[server_name]
|
||||||
|
|
||||||
|
if verify_request.deferred.called:
|
||||||
|
# We've already called this deferred, which probably
|
||||||
|
# means that we've already found a key for it.
|
||||||
|
continue
|
||||||
|
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
if key_id in result_keys:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
group_id_to_deferred[group.group_id].callback((
|
verify_request.deferred.callback((
|
||||||
group.group_id,
|
server_name,
|
||||||
group.server_name,
|
|
||||||
key_id,
|
key_id,
|
||||||
merged_results[group.server_name][key_id],
|
result_keys[key_id],
|
||||||
))
|
))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
missing_groups.setdefault(
|
# The else block is only reached if the loop above
|
||||||
group.server_name, []
|
# doesn't break.
|
||||||
).append(group)
|
missing_keys.setdefault(server_name, set()).update(
|
||||||
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
requests_missing_keys.append(verify_request)
|
||||||
|
|
||||||
if not missing_groups:
|
if not missing_keys:
|
||||||
break
|
break
|
||||||
|
|
||||||
missing_keys = {
|
for verify_request in requests_missing_keys.values():
|
||||||
server_name: set(
|
verify_request.deferred.errback(SynapseError(
|
||||||
key_id for group in groups for key_id in group.key_ids
|
|
||||||
)
|
|
||||||
for server_name, groups in missing_groups.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
for group in missing_groups.values():
|
|
||||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
|
||||||
401,
|
401,
|
||||||
"No key for %s with id %s" % (
|
"No key for %s with id %s" % (
|
||||||
group.server_name, group.key_ids,
|
verify_request.server_name, verify_request.key_ids,
|
||||||
),
|
),
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
for deferred in group_id_to_deferred.values():
|
for verify_request in verify_requests:
|
||||||
if not deferred.called:
|
if not verify_request.deferred.called:
|
||||||
deferred.errback(err)
|
verify_request.deferred.errback(err)
|
||||||
|
|
||||||
do_iterations().addErrback(on_err)
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
return group_id_to_deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
res = yield defer.gatherResults(
|
res = yield defer.gatherResults(
|
||||||
|
Loading…
Reference in New Issue
Block a user