Run black on synapse.crypto.keyring (#5232)

This commit is contained in:
Richard van der Hoff 2019-05-22 18:39:33 +01:00 committed by GitHub
parent 73f1de31d1
commit 1a94de60e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 138 additions and 149 deletions

1
changelog.d/5232.misc Normal file
View File

@ -0,0 +1 @@
Run black on synapse.crypto.keyring.

View File

@ -56,9 +56,9 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple("VerifyRequest", ( VerifyKeyRequest = namedtuple(
"server_name", "key_ids", "json_object", "deferred" "VerifyRequest", ("server_name", "key_ids", "json_object", "deferred")
)) )
""" """
A request for a verify key to verify a JSON object. A request for a verify key to verify a JSON object.
@ -96,9 +96,7 @@ class Keyring(object):
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
return logcontext.make_deferred_yieldable( return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server( self.verify_json_objects_for_server([(server_name, json_object)])[0]
[(server_name, json_object)]
)[0]
) )
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
@ -130,18 +128,15 @@ class Keyring(object):
if not key_ids: if not key_ids:
return defer.fail( return defer.fail(
SynapseError( SynapseError(
400, 400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
"Not signed by %s" % (server_name,),
Codes.UNAUTHORIZED,
) )
) )
logger.debug("Verifying for %s with key_ids %s", logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
server_name, key_ids)
# add the key request to the queue, but don't start it off yet. # add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest( verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred(), server_name, key_ids, json_object, defer.Deferred()
) )
verify_requests.append(verify_request) verify_requests.append(verify_request)
@ -179,15 +174,13 @@ class Keyring(object):
# any other lookups until we have finished. # any other lookups until we have finished.
# The deferreds are called with no logcontext. # The deferreds are called with no logcontext.
server_to_deferred = { server_to_deferred = {
rq.server_name: defer.Deferred() rq.server_name: defer.Deferred() for rq in verify_requests
for rq in verify_requests
} }
# We want to wait for any previous lookups to complete before # We want to wait for any previous lookups to complete before
# proceeding. # proceeding.
yield self.wait_for_previous_lookups( yield self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests], [rq.server_name for rq in verify_requests], server_to_deferred
server_to_deferred,
) )
# Actually start fetching keys. # Actually start fetching keys.
@ -216,9 +209,7 @@ class Keyring(object):
return res return res
for verify_request in verify_requests: for verify_request in verify_requests:
verify_request.deferred.addBoth( verify_request.deferred.addBoth(remove_deferreds, verify_request)
remove_deferreds, verify_request,
)
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@ -248,7 +239,8 @@ class Keyring(object):
break break
logger.info( logger.info(
"Waiting for existing lookups for %s to complete [loop %i]", "Waiting for existing lookups for %s to complete [loop %i]",
[w[0] for w in wait_on], loop_count, [w[0] for w in wait_on],
loop_count,
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on)) yield defer.DeferredList((w[1] for w in wait_on))
@ -335,13 +327,14 @@ class Keyring(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
for verify_request in requests_missing_keys: for verify_request in requests_missing_keys:
verify_request.deferred.errback(SynapseError( verify_request.deferred.errback(
SynapseError(
401, 401,
"No key for %s with id %s" % ( "No key for %s with id %s"
verify_request.server_name, verify_request.key_ids, % (verify_request.server_name, verify_request.key_ids),
),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) )
)
def on_err(err): def on_err(err):
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -383,25 +376,26 @@ class Keyring(object):
) )
defer.returnValue(result) defer.returnValue(result)
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning("Key lookup failed from %r: %s", perspective_name, e)
"Key lookup failed from %r: %s", perspective_name, e,
)
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Unable to get key from %r: %s %s", "Unable to get key from %r: %s %s",
perspective_name, perspective_name,
type(e).__name__, str(e), type(e).__name__,
str(e),
) )
defer.returnValue({}) defer.returnValue({})
results = yield logcontext.make_deferred_yieldable(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background(get_key, p_name, p_keys) run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError)
)
union_of_keys = {} union_of_keys = {}
for result in results: for result in results:
@ -412,32 +406,30 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids): def get_keys_from_server(self, server_name_and_key_ids):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background( run_in_background(
self.get_server_verify_key_v2_direct, self.get_server_verify_key_v2_direct, server_name, key_ids
server_name,
key_ids,
) )
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError)
)
merged = {} merged = {}
for result in results: for result in results:
merged.update(result) merged.update(result)
defer.returnValue({ defer.returnValue(
server_name: keys {server_name: keys for server_name, keys in merged.items() if keys}
for server_name, keys in merged.items() )
if keys
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, def get_server_verify_key_v2_indirect(
perspective_name, self, server_names_and_key_ids, perspective_name, perspective_keys
perspective_keys): ):
# TODO(mark): Set the minimum_valid_until_ts to that needed by # TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating # the events being validated or the current time if validating
# an incoming request. # an incoming request.
@ -448,9 +440,7 @@ class Keyring(object):
data={ data={
u"server_keys": { u"server_keys": {
server_name: { server_name: {
key_id: { key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
u"minimum_valid_until_ts": 0
} for key_id in key_ids
} }
for server_name, key_ids in server_names_and_key_ids for server_name, key_ids in server_names_and_key_ids
} }
@ -458,21 +448,19 @@ class Keyring(object):
long_retries=True, long_retries=True,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from( raise_from(KeyLookupError("Failed to connect to remote server"), e)
KeyLookupError("Failed to connect to remote server"), e,
)
except HttpResponseException as e: except HttpResponseException as e:
raise_from( raise_from(KeyLookupError("Remote server returned an error"), e)
KeyLookupError("Remote server returned an error"), e,
)
keys = {} keys = {}
responses = query_response["server_keys"] responses = query_response["server_keys"]
for response in responses: for response in responses:
if (u"signatures" not in response if (
or perspective_name not in response[u"signatures"]): u"signatures" not in response
or perspective_name not in response[u"signatures"]
):
raise KeyLookupError( raise KeyLookupError(
"Key response not signed by perspective server" "Key response not signed by perspective server"
" %r" % (perspective_name,) " %r" % (perspective_name,)
@ -482,9 +470,7 @@ class Keyring(object):
for key_id in response[u"signatures"][perspective_name]: for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys: if key_id in perspective_keys:
verify_signed_json( verify_signed_json(
response, response, perspective_name, perspective_keys[key_id]
perspective_name,
perspective_keys[key_id]
) )
verified = True verified = True
@ -494,7 +480,7 @@ class Keyring(object):
" known key, signed with: %r, known keys: %r", " known key, signed with: %r, known keys: %r",
perspective_name, perspective_name,
list(response[u"signatures"][perspective_name]), list(response[u"signatures"][perspective_name]),
list(perspective_keys) list(perspective_keys),
) )
raise KeyLookupError( raise KeyLookupError(
"Response not signed with a known key for perspective" "Response not signed with a known key for perspective"
@ -508,7 +494,8 @@ class Keyring(object):
keys.setdefault(server_name, {}).update(processed_response) keys.setdefault(server_name, {}).update(processed_response)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background( run_in_background(
self.store_keys, self.store_keys,
@ -518,8 +505,9 @@ class Keyring(object):
) )
for server_name, response_keys in keys.items() for server_name, response_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError)
)
defer.returnValue(keys) defer.returnValue(keys)
@ -534,26 +522,26 @@ class Keyring(object):
try: try:
response = yield self.client.get_json( response = yield self.client.get_json(
destination=server_name, destination=server_name,
path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True, ignore_backoff=True,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from( raise_from(KeyLookupError("Failed to connect to remote server"), e)
KeyLookupError("Failed to connect to remote server"), e,
)
except HttpResponseException as e: except HttpResponseException as e:
raise_from( raise_from(KeyLookupError("Remote server returned an error"), e)
KeyLookupError("Remote server returned an error"), e,
)
if (u"signatures" not in response if (
or server_name not in response[u"signatures"]): u"signatures" not in response
or server_name not in response[u"signatures"]
):
raise KeyLookupError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name: if response["server_name"] != server_name:
raise KeyLookupError("Expected a response for server %r not %r" % ( raise KeyLookupError(
server_name, response["server_name"] "Expected a response for server %r not %r"
)) % (server_name, response["server_name"])
)
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
@ -564,16 +552,12 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield self.store_keys( yield self.store_keys(
server_name=server_name, server_name=server_name, from_server=server_name, verify_keys=keys
from_server=server_name,
verify_keys=keys,
) )
defer.returnValue({server_name: keys}) defer.returnValue({server_name: keys})
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response( def process_v2_response(self, from_server, response_json, requested_ids=[]):
self, from_server, response_json, requested_ids=[],
):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from This is used to parse either the entirety of the response from
@ -627,20 +611,13 @@ class Keyring(object):
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise KeyLookupError( raise KeyLookupError(
"Key response must include verification keys for all" "Key response must include verification keys for all" " signatures"
" signatures"
) )
if key_id in verify_keys: if key_id in verify_keys:
verify_signed_json( verify_signed_json(response_json, server_name, verify_keys[key_id])
response_json,
server_name,
verify_keys[key_id]
)
signed_key_json = sign_json( signed_key_json = sign_json(
response_json, response_json, self.config.server_name, self.config.signing_key[0]
self.config.server_name,
self.config.signing_key[0],
) )
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
@ -653,7 +630,8 @@ class Keyring(object):
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background( run_in_background(
self.store.store_server_keys_json, self.store.store_server_keys_json,
@ -667,7 +645,8 @@ class Keyring(object):
for key_id in updated_key_ids for key_id in updated_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError)
)
defer.returnValue(response_keys) defer.returnValue(response_keys)
@ -681,16 +660,21 @@ class Keyring(object):
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(defer.gatherResults( return logcontext.make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background( run_in_background(
self.store.store_server_verify_key, self.store.store_server_verify_key,
server_name, server_name, key.time_added, key server_name,
server_name,
key.time_added,
key,
) )
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -713,17 +697,19 @@ def _handle_key_deferred(verify_request):
except KeyLookupError as e: except KeyLookupError as e:
logger.warn( logger.warn(
"Failed to download keys for %s: %s %s", "Failed to download keys for %s: %s %s",
server_name, type(e).__name__, str(e), server_name,
type(e).__name__,
str(e),
) )
raise SynapseError( raise SynapseError(
502, 502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
) )
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Got Exception when downloading keys for %s: %s %s", "Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e), server_name,
type(e).__name__,
str(e),
) )
raise SynapseError( raise SynapseError(
401, 401,
@ -733,22 +719,24 @@ def _handle_key_deferred(verify_request):
json_object = verify_request.json_object json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % ( logger.debug(
key_id, verify_key.alg, verify_key.version, server_name, "Got key %s %s:%s for server %s, verifying"
)) % (key_id, verify_key.alg, verify_key.version, server_name)
)
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e: except SignatureVerifyException as e:
logger.debug( logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s", "Error verifying signature for %s:%s:%s with key %s: %s",
server_name, verify_key.alg, verify_key.version, server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key), encode_verify_key_base64(verify_key),
str(e), str(e),
) )
raise SynapseError( raise SynapseError(
401, 401,
"Invalid signature for server %s with key %s:%s: %s" % ( "Invalid signature for server %s with key %s:%s: %s"
server_name, verify_key.alg, verify_key.version, str(e), % (server_name, verify_key.alg, verify_key.version, str(e)),
),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )