mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-01 12:06:08 -04:00
Enforce validity period on server_keys for fed requests. (#5321)
When handling incoming federation requests, make sure that we have an up-to-date copy of the signing key. We do not yet enforce the validity period for event signatures.
This commit is contained in:
parent
fe2294ec8d
commit
fec2dcb1a5
6 changed files with 228 additions and 88 deletions
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import six
|
||||
from six import raise_from
|
||||
|
@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
|
|||
|
||||
json_object(dict): The JSON object to verify.
|
||||
|
||||
minimum_valid_until_ts (int): time at which we require the signing key to
|
||||
be valid. (0 implies we don't care)
|
||||
|
||||
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
|
||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
||||
|
@ -82,7 +86,8 @@ class VerifyKeyRequest(object):
|
|||
server_name = attr.ib()
|
||||
key_ids = attr.ib()
|
||||
json_object = attr.ib()
|
||||
deferred = attr.ib()
|
||||
minimum_valid_until_ts = attr.ib()
|
||||
deferred = attr.ib(default=attr.Factory(defer.Deferred))
|
||||
|
||||
|
||||
class KeyLookupError(ValueError):
|
||||
|
@ -90,14 +95,16 @@ class KeyLookupError(ValueError):
|
|||
|
||||
|
||||
class Keyring(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs, key_fetchers=None):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._key_fetchers = (
|
||||
StoreKeyFetcher(hs),
|
||||
PerspectivesKeyFetcher(hs),
|
||||
ServerKeyFetcher(hs),
|
||||
)
|
||||
if key_fetchers is None:
|
||||
key_fetchers = (
|
||||
StoreKeyFetcher(hs),
|
||||
PerspectivesKeyFetcher(hs),
|
||||
ServerKeyFetcher(hs),
|
||||
)
|
||||
self._key_fetchers = key_fetchers
|
||||
|
||||
# map from server name to Deferred. Has an entry for each server with
|
||||
# an ongoing key download; the Deferred completes once the download
|
||||
|
@ -106,9 +113,25 @@ class Keyring(object):
|
|||
# These are regular, logcontext-agnostic Deferreds.
|
||||
self.key_downloads = {}
|
||||
|
||||
def verify_json_for_server(self, server_name, json_object):
|
||||
def verify_json_for_server(self, server_name, json_object, validity_time):
|
||||
"""Verify that a JSON object has been signed by a given server
|
||||
|
||||
Args:
|
||||
server_name (str): name of the server which must have signed this object
|
||||
|
||||
json_object (dict): object to be checked
|
||||
|
||||
validity_time (int): timestamp at which we require the signing key to
|
||||
be valid. (0 implies we don't care)
|
||||
|
||||
Returns:
|
||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
||||
errbacks with an error
|
||||
"""
|
||||
req = server_name, json_object, validity_time
|
||||
|
||||
return logcontext.make_deferred_yieldable(
|
||||
self.verify_json_objects_for_server([(server_name, json_object)])[0]
|
||||
self.verify_json_objects_for_server((req,))[0]
|
||||
)
|
||||
|
||||
def verify_json_objects_for_server(self, server_and_json):
|
||||
|
@ -116,10 +139,12 @@ class Keyring(object):
|
|||
necessary.
|
||||
|
||||
Args:
|
||||
server_and_json (list): List of pairs of (server_name, json_object)
|
||||
server_and_json (iterable[Tuple[str, dict, int]):
|
||||
Iterable of triplets of (server_name, json_object, validity_time)
|
||||
validity_time is a timestamp at which the signing key must be valid.
|
||||
|
||||
Returns:
|
||||
List<Deferred>: for each input pair, a deferred indicating success
|
||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||
or failure to verify each json object's signature for the given
|
||||
server_name. The deferreds run their callbacks in the sentinel
|
||||
logcontext.
|
||||
|
@ -128,12 +153,12 @@ class Keyring(object):
|
|||
verify_requests = []
|
||||
handle = preserve_fn(_handle_key_deferred)
|
||||
|
||||
def process(server_name, json_object):
|
||||
def process(server_name, json_object, validity_time):
|
||||
"""Process an entry in the request list
|
||||
|
||||
Given a (server_name, json_object) pair from the request list,
|
||||
adds a key request to verify_requests, and returns a deferred which will
|
||||
complete or fail (in the sentinel context) when verification completes.
|
||||
Given a (server_name, json_object, validity_time) triplet from the request
|
||||
list, adds a key request to verify_requests, and returns a deferred which
|
||||
will complete or fail (in the sentinel context) when verification completes.
|
||||
"""
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
|
||||
|
@ -148,7 +173,7 @@ class Keyring(object):
|
|||
|
||||
# add the key request to the queue, but don't start it off yet.
|
||||
verify_request = VerifyKeyRequest(
|
||||
server_name, key_ids, json_object, defer.Deferred()
|
||||
server_name, key_ids, json_object, validity_time
|
||||
)
|
||||
verify_requests.append(verify_request)
|
||||
|
||||
|
@ -160,8 +185,8 @@ class Keyring(object):
|
|||
return handle(verify_request)
|
||||
|
||||
results = [
|
||||
process(server_name, json_object)
|
||||
for server_name, json_object in server_and_json
|
||||
process(server_name, json_object, validity_time)
|
||||
for server_name, json_object, validity_time in server_and_json
|
||||
]
|
||||
|
||||
if verify_requests:
|
||||
|
@ -298,8 +323,12 @@ class Keyring(object):
|
|||
verify_request.deferred.errback(
|
||||
SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s"
|
||||
% (verify_request.server_name, verify_request.key_ids),
|
||||
"No key for %s with ids in %s (min_validity %i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
)
|
||||
|
@ -323,18 +352,28 @@ class Keyring(object):
|
|||
Args:
|
||||
fetcher (KeyFetcher): fetcher to use to fetch the keys
|
||||
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
|
||||
Any successfully-completed requests will be reomved from the list.
|
||||
Any successfully-completed requests will be removed from the list.
|
||||
"""
|
||||
# dict[str, set(str)]: keys to fetch for each server
|
||||
missing_keys = {}
|
||||
# dict[str, dict[str, int]]: keys to fetch.
|
||||
# server_name -> key_id -> min_valid_ts
|
||||
missing_keys = defaultdict(dict)
|
||||
|
||||
for verify_request in remaining_requests:
|
||||
# any completed requests should already have been removed
|
||||
assert not verify_request.deferred.called
|
||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
keys_for_server = missing_keys[verify_request.server_name]
|
||||
|
||||
results = yield fetcher.get_keys(missing_keys.items())
|
||||
for key_id in verify_request.key_ids:
|
||||
# If we have several requests for the same key, then we only need to
|
||||
# request that key once, but we should do so with the greatest
|
||||
# min_valid_until_ts of the requests, so that we can satisfy all of
|
||||
# the requests.
|
||||
keys_for_server[key_id] = max(
|
||||
keys_for_server.get(key_id, -1),
|
||||
verify_request.minimum_valid_until_ts
|
||||
)
|
||||
|
||||
results = yield fetcher.get_keys(missing_keys)
|
||||
|
||||
completed = list()
|
||||
for verify_request in remaining_requests:
|
||||
|
@ -344,25 +383,34 @@ class Keyring(object):
|
|||
# complete this VerifyKeyRequest.
|
||||
result_keys = results.get(server_name, {})
|
||||
for key_id in verify_request.key_ids:
|
||||
key = result_keys.get(key_id)
|
||||
if key:
|
||||
with PreserveLoggingContext():
|
||||
verify_request.deferred.callback(
|
||||
(server_name, key_id, key.verify_key)
|
||||
)
|
||||
completed.append(verify_request)
|
||||
break
|
||||
fetch_key_result = result_keys.get(key_id)
|
||||
if not fetch_key_result:
|
||||
# we didn't get a result for this key
|
||||
continue
|
||||
|
||||
if (
|
||||
fetch_key_result.valid_until_ts
|
||||
< verify_request.minimum_valid_until_ts
|
||||
):
|
||||
# key was not valid at this point
|
||||
continue
|
||||
|
||||
with PreserveLoggingContext():
|
||||
verify_request.deferred.callback(
|
||||
(server_name, key_id, fetch_key_result.verify_key)
|
||||
)
|
||||
completed.append(verify_request)
|
||||
break
|
||||
|
||||
remaining_requests.difference_update(completed)
|
||||
|
||||
|
||||
class KeyFetcher(object):
|
||||
def get_keys(self, server_name_and_key_ids):
|
||||
def get_keys(self, keys_to_fetch):
|
||||
"""
|
||||
Args:
|
||||
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
|
||||
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
||||
Note that the iterables may be iterated more than once.
|
||||
keys_to_fetch (dict[str, dict[str, int]]):
|
||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
||||
|
@ -378,13 +426,15 @@ class StoreKeyFetcher(KeyFetcher):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys(self, server_name_and_key_ids):
|
||||
def get_keys(self, keys_to_fetch):
|
||||
"""see KeyFetcher.get_keys"""
|
||||
|
||||
keys_to_fetch = (
|
||||
(server_name, key_id)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
for key_id in key_ids
|
||||
for server_name, keys_for_server in keys_to_fetch.items()
|
||||
for key_id in keys_for_server.keys()
|
||||
)
|
||||
|
||||
res = yield self.store.get_server_verify_keys(keys_to_fetch)
|
||||
keys = {}
|
||||
for (server_name, key_id), key in res.items():
|
||||
|
@ -508,14 +558,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
self.perspective_servers = self.config.perspectives
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys(self, server_name_and_key_ids):
|
||||
def get_keys(self, keys_to_fetch):
|
||||
"""see KeyFetcher.get_keys"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_key(perspective_name, perspective_keys):
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
server_name_and_key_ids, perspective_name, perspective_keys
|
||||
keys_to_fetch, perspective_name, perspective_keys
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except KeyLookupError as e:
|
||||
|
@ -549,13 +599,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_indirect(
|
||||
self, server_names_and_key_ids, perspective_name, perspective_keys
|
||||
self, keys_to_fetch, perspective_name, perspective_keys
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
|
||||
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
||||
keys_to_fetch (dict[str, dict[str, int]]):
|
||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||
|
||||
perspective_name (str): name of the notary server to query for the keys
|
||||
|
||||
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
|
||||
notary server
|
||||
|
||||
|
@ -569,12 +621,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
"""
|
||||
logger.info(
|
||||
"Requesting keys %s from notary server %s",
|
||||
server_names_and_key_ids,
|
||||
keys_to_fetch.items(),
|
||||
perspective_name,
|
||||
)
|
||||
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
||||
# the events being validated or the current time if validating
|
||||
# an incoming request.
|
||||
|
||||
try:
|
||||
query_response = yield self.client.post_json(
|
||||
destination=perspective_name,
|
||||
|
@ -582,9 +632,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
data={
|
||||
u"server_keys": {
|
||||
server_name: {
|
||||
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
|
||||
key_id: {u"minimum_valid_until_ts": min_valid_ts}
|
||||
for key_id, min_valid_ts in server_keys.items()
|
||||
}
|
||||
for server_name, key_ids in server_names_and_key_ids
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
}
|
||||
},
|
||||
long_retries=True,
|
||||
|
@ -694,15 +745,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||
self.client = hs.get_http_client()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys(self, server_name_and_key_ids):
|
||||
def get_keys(self, keys_to_fetch):
|
||||
"""see KeyFetcher.get_keys"""
|
||||
# TODO make this more resilient
|
||||
results = yield logcontext.make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
self.get_server_verify_key_v2_direct, server_name, key_ids
|
||||
self.get_server_verify_key_v2_direct,
|
||||
server_name,
|
||||
server_keys.keys(),
|
||||
)
|
||||
for server_name, key_ids in server_name_and_key_ids
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
@ -721,6 +775,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||
keys = {} # type: dict[str, FetchKeyResult]
|
||||
|
||||
for requested_key_id in key_ids:
|
||||
# we may have found this key as a side-effect of asking for another.
|
||||
if requested_key_id in keys:
|
||||
continue
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue