Clean up Keyring.process_v2_response

Make this just return the key dict, rather than a single-entry dict mapping the
server name to the key dict. It's easy for the caller to get the server name
from from the response object anyway.
This commit is contained in:
Richard van der Hoff 2019-04-04 19:12:54 +01:00
parent b2d574f126
commit ef27d434d1

View File

@ -20,6 +20,7 @@ from collections import namedtuple
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
import nacl.signing
from signedjson.key import ( from signedjson.key import (
decode_verify_key_bytes, decode_verify_key_bytes,
encode_verify_key_base64, encode_verify_key_base64,
@ -496,9 +497,9 @@ class Keyring(object):
processed_response = yield self.process_v2_response( processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False perspective_name, response, only_from_server=False
) )
server_name = response["server_name"]
for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(processed_response)
keys.setdefault(server_name, {}).update(response_keys)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
@ -517,7 +518,7 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {} # type: dict[str, nacl.signing.VerifyKey]
for requested_key_id in key_ids: for requested_key_id in key_ids:
if requested_key_id in keys: if requested_key_id in keys:
@ -550,24 +551,49 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield self.store_keys(
[ server_name=server_name,
run_in_background(
self.store_keys,
server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=verify_keys, verify_keys=keys,
) )
for key_server_name, verify_keys in keys.items() defer.returnValue({server_name: keys})
],
consumeErrors=True
).addErrback(unwrapFirstError))
defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, def process_v2_response(
requested_ids=[], only_from_server=True): self, from_server, response_json, requested_ids=[], only_from_server=True
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
only_from_server (bool): if True, we will check that the server_name in the
the response (ie, the server which originated the key) matches
from_server.
Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
"""
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -589,7 +615,6 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"] server_name = response_json["server_name"]
if only_from_server: if only_from_server:
if server_name != from_server: if server_name != from_server:
@ -643,9 +668,7 @@ class Keyring(object):
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError))
results[server_name] = response_keys defer.returnValue(response_keys)
defer.returnValue(results)
def store_keys(self, server_name, from_server, verify_keys): def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server """Store a collection of verify keys for a given server