mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-19 04:21:34 -05:00
Rewrite the KeyRing (#10035)
This commit is contained in:
parent
3cf6b34b4e
commit
fc3d2dc269
1
changelog.d/10035.feature
Normal file
1
changelog.d/10035.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory.
|
@ -16,8 +16,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from collections import defaultdict
|
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import (
|
from signedjson.key import (
|
||||||
@ -44,17 +43,12 @@ from synapse.api.errors import (
|
|||||||
from synapse.config.key import TrustedKeyServer
|
from synapse.config.key import TrustedKeyServer
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import prune_event_dict
|
from synapse.events.utils import prune_event_dict
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
PreserveLoggingContext,
|
|
||||||
make_deferred_yieldable,
|
|
||||||
preserve_fn,
|
|
||||||
run_in_background,
|
|
||||||
)
|
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import yieldable_gather_results
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.batching_queue import BatchingQueue
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -80,32 +74,19 @@ class VerifyJsonRequest:
|
|||||||
minimum_valid_until_ts: time at which we require the signing key to
|
minimum_valid_until_ts: time at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
request_name: The name of the request.
|
|
||||||
|
|
||||||
key_ids: The set of key_ids to that could be used to verify the JSON object
|
key_ids: The set of key_ids to that could be used to verify the JSON object
|
||||||
|
|
||||||
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
|
|
||||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
|
||||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
|
||||||
logcontext.
|
|
||||||
|
|
||||||
If we are unable to find a key which satisfies the request, the deferred
|
|
||||||
errbacks with an M_UNAUTHORIZED SynapseError.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
server_name = attr.ib(type=str)
|
server_name = attr.ib(type=str)
|
||||||
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
||||||
minimum_valid_until_ts = attr.ib(type=int)
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
request_name = attr.ib(type=str)
|
|
||||||
key_ids = attr.ib(type=List[str])
|
key_ids = attr.ib(type=List[str])
|
||||||
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json_object(
|
def from_json_object(
|
||||||
server_name: str,
|
server_name: str,
|
||||||
json_object: JsonDict,
|
json_object: JsonDict,
|
||||||
minimum_valid_until_ms: int,
|
minimum_valid_until_ms: int,
|
||||||
request_name: str,
|
|
||||||
):
|
):
|
||||||
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
||||||
object for the given server.
|
object for the given server.
|
||||||
@ -115,7 +96,6 @@ class VerifyJsonRequest:
|
|||||||
server_name,
|
server_name,
|
||||||
lambda: json_object,
|
lambda: json_object,
|
||||||
minimum_valid_until_ms,
|
minimum_valid_until_ms,
|
||||||
request_name=request_name,
|
|
||||||
key_ids=key_ids,
|
key_ids=key_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -135,16 +115,48 @@ class VerifyJsonRequest:
|
|||||||
# memory than the Event object itself.
|
# memory than the Event object itself.
|
||||||
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
||||||
minimum_valid_until_ms,
|
minimum_valid_until_ms,
|
||||||
request_name=event.event_id,
|
|
||||||
key_ids=key_ids,
|
key_ids=key_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_fetch_key_request(self) -> "_FetchKeyRequest":
|
||||||
|
"""Create a key fetch request for all keys needed to satisfy the
|
||||||
|
verification request.
|
||||||
|
"""
|
||||||
|
return _FetchKeyRequest(
|
||||||
|
server_name=self.server_name,
|
||||||
|
minimum_valid_until_ts=self.minimum_valid_until_ts,
|
||||||
|
key_ids=self.key_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KeyLookupError(ValueError):
|
class KeyLookupError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _FetchKeyRequest:
|
||||||
|
"""A request for keys for a given server.
|
||||||
|
|
||||||
|
We will continue to try and fetch until we have all the keys listed under
|
||||||
|
`key_ids` (with an appropriate `valid_until_ts` property) or we run out of
|
||||||
|
places to fetch keys from.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name: The name of the server that owns the keys.
|
||||||
|
minimum_valid_until_ts: The timestamp which the keys must be valid until.
|
||||||
|
key_ids: The IDs of the keys to attempt to fetch
|
||||||
|
"""
|
||||||
|
|
||||||
|
server_name = attr.ib(type=str)
|
||||||
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
|
key_ids = attr.ib(type=List[str])
|
||||||
|
|
||||||
|
|
||||||
class Keyring:
|
class Keyring:
|
||||||
|
"""Handles verifying signed JSON objects and fetching the keys needed to do
|
||||||
|
so.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
||||||
):
|
):
|
||||||
@ -158,22 +170,22 @@ class Keyring:
|
|||||||
)
|
)
|
||||||
self._key_fetchers = key_fetchers
|
self._key_fetchers = key_fetchers
|
||||||
|
|
||||||
# map from server name to Deferred. Has an entry for each server with
|
self._server_queue = BatchingQueue(
|
||||||
# an ongoing key download; the Deferred completes once the download
|
"keyring_server",
|
||||||
# completes.
|
clock=hs.get_clock(),
|
||||||
#
|
process_batch_callback=self._inner_fetch_key_requests,
|
||||||
# These are regular, logcontext-agnostic Deferreds.
|
) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
|
||||||
self.key_downloads = {} # type: Dict[str, defer.Deferred]
|
|
||||||
|
|
||||||
def verify_json_for_server(
|
async def verify_json_for_server(
|
||||||
self,
|
self,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
json_object: JsonDict,
|
json_object: JsonDict,
|
||||||
validity_time: int,
|
validity_time: int,
|
||||||
request_name: str,
|
) -> None:
|
||||||
) -> defer.Deferred:
|
|
||||||
"""Verify that a JSON object has been signed by a given server
|
"""Verify that a JSON object has been signed by a given server
|
||||||
|
|
||||||
|
Completes if the the object was correctly signed, otherwise raises.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name: name of the server which must have signed this object
|
server_name: name of the server which must have signed this object
|
||||||
|
|
||||||
@ -181,52 +193,45 @@ class Keyring:
|
|||||||
|
|
||||||
validity_time: timestamp at which we require the signing key to
|
validity_time: timestamp at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
request_name: an identifier for this json object (eg, an event id)
|
|
||||||
for logging.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
|
||||||
errbacks with an error
|
|
||||||
"""
|
"""
|
||||||
request = VerifyJsonRequest.from_json_object(
|
request = VerifyJsonRequest.from_json_object(
|
||||||
server_name,
|
server_name,
|
||||||
json_object,
|
json_object,
|
||||||
validity_time,
|
validity_time,
|
||||||
request_name,
|
|
||||||
)
|
)
|
||||||
requests = (request,)
|
return await self.process_request(request)
|
||||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
|
||||||
|
|
||||||
def verify_json_objects_for_server(
|
def verify_json_objects_for_server(
|
||||||
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
|
self, server_and_json: Iterable[Tuple[str, dict, int]]
|
||||||
) -> List[defer.Deferred]:
|
) -> List[defer.Deferred]:
|
||||||
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
||||||
necessary.
|
necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_and_json:
|
server_and_json:
|
||||||
Iterable of (server_name, json_object, validity_time, request_name)
|
Iterable of (server_name, json_object, validity_time)
|
||||||
tuples.
|
tuples.
|
||||||
|
|
||||||
validity_time is a timestamp at which the signing key must be
|
validity_time is a timestamp at which the signing key must be
|
||||||
valid.
|
valid.
|
||||||
|
|
||||||
request_name is an identifier for this json object (eg, an event id)
|
|
||||||
for logging.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||||
or failure to verify each json object's signature for the given
|
or failure to verify each json object's signature for the given
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
logcontext.
|
logcontext.
|
||||||
"""
|
"""
|
||||||
return self._verify_objects(
|
return [
|
||||||
VerifyJsonRequest.from_json_object(
|
run_in_background(
|
||||||
server_name, json_object, validity_time, request_name
|
self.process_request,
|
||||||
|
VerifyJsonRequest.from_json_object(
|
||||||
|
server_name,
|
||||||
|
json_object,
|
||||||
|
validity_time,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for server_name, json_object, validity_time, request_name in server_and_json
|
for server_name, json_object, validity_time in server_and_json
|
||||||
)
|
]
|
||||||
|
|
||||||
def verify_events_for_server(
|
def verify_events_for_server(
|
||||||
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
||||||
@ -252,321 +257,223 @@ class Keyring:
|
|||||||
server_name. The deferreds run their callbacks in the sentinel
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
logcontext.
|
logcontext.
|
||||||
"""
|
"""
|
||||||
return self._verify_objects(
|
return [
|
||||||
VerifyJsonRequest.from_event(server_name, event, validity_time)
|
run_in_background(
|
||||||
|
self.process_request,
|
||||||
|
VerifyJsonRequest.from_event(
|
||||||
|
server_name,
|
||||||
|
event,
|
||||||
|
validity_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
for server_name, event, validity_time in server_and_events
|
for server_name, event, validity_time in server_and_events
|
||||||
|
]
|
||||||
|
|
||||||
|
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
|
||||||
|
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
|
||||||
|
by the server, the signatures don't match or we failed to fetch the
|
||||||
|
necessary keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not verify_request.key_ids:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
f"Not signed by {verify_request.server_name}",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the keys we need to verify to the queue for retrieval. We queue
|
||||||
|
# up requests for the same server so we don't end up with many in flight
|
||||||
|
# requests for the same keys.
|
||||||
|
key_request = verify_request.to_fetch_key_request()
|
||||||
|
found_keys_by_server = await self._server_queue.add_to_queue(
|
||||||
|
key_request, key=verify_request.server_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def _verify_objects(
|
# Since we batch up requests the returned set of keys may contain keys
|
||||||
self, verify_requests: Iterable[VerifyJsonRequest]
|
# from other servers, so we pull out only the ones we care about.s
|
||||||
) -> List[defer.Deferred]:
|
found_keys = found_keys_by_server.get(verify_request.server_name, {})
|
||||||
"""Does the work of verify_json_[objects_]for_server
|
|
||||||
|
|
||||||
|
# Verify each signature we got valid keys for, raising if we can't
|
||||||
|
# verify any of them.
|
||||||
|
verified = False
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
key_result = found_keys.get(key_id)
|
||||||
|
if not key_result:
|
||||||
|
continue
|
||||||
|
|
||||||
Args:
|
if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
|
||||||
verify_requests: Iterable of verification requests.
|
continue
|
||||||
|
|
||||||
Returns:
|
verify_key = key_result.verify_key
|
||||||
List<Deferred[None]>: for each input item, a deferred indicating success
|
json_object = verify_request.get_json_object()
|
||||||
or failure to verify each json object's signature for the given
|
try:
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
verify_signed_json(
|
||||||
logcontext.
|
json_object,
|
||||||
"""
|
verify_request.server_name,
|
||||||
# a list of VerifyJsonRequests which are awaiting a key lookup
|
verify_key,
|
||||||
key_lookups = []
|
)
|
||||||
handle = preserve_fn(_handle_key_deferred)
|
verified = True
|
||||||
|
except SignatureVerifyException as e:
|
||||||
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
|
logger.debug(
|
||||||
"""Process an entry in the request list
|
"Error verifying signature for %s:%s:%s with key %s: %s",
|
||||||
|
verify_request.server_name,
|
||||||
Adds a key request to key_lookups, and returns a deferred which
|
verify_key.alg,
|
||||||
will complete or fail (in the sentinel context) when verification completes.
|
verify_key.version,
|
||||||
"""
|
encode_verify_key_base64(verify_key),
|
||||||
if not verify_request.key_ids:
|
str(e),
|
||||||
return defer.fail(
|
)
|
||||||
SynapseError(
|
raise SynapseError(
|
||||||
400,
|
401,
|
||||||
"Not signed by %s" % (verify_request.server_name,),
|
"Invalid signature for server %s with key %s:%s: %s"
|
||||||
Codes.UNAUTHORIZED,
|
% (
|
||||||
)
|
verify_request.server_name,
|
||||||
|
verify_key.alg,
|
||||||
|
verify_key.version,
|
||||||
|
str(e),
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
if not verified:
|
||||||
"Verifying %s for %s with key_ids %s, min_validity %i",
|
raise SynapseError(
|
||||||
verify_request.request_name,
|
401,
|
||||||
|
f"Failed to find any key to satisfy: {key_request}",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _inner_fetch_key_requests(
|
||||||
|
self, requests: List[_FetchKeyRequest]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
|
"""Processing function for the queue of `_FetchKeyRequest`."""
|
||||||
|
|
||||||
|
logger.debug("Starting fetch for %s", requests)
|
||||||
|
|
||||||
|
# First we need to deduplicate requests for the same key. We do this by
|
||||||
|
# taking the *maximum* requested `minimum_valid_until_ts` for each pair
|
||||||
|
# of server name/key ID.
|
||||||
|
server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]]
|
||||||
|
for request in requests:
|
||||||
|
by_server = server_to_key_to_ts.setdefault(request.server_name, {})
|
||||||
|
for key_id in request.key_ids:
|
||||||
|
existing_ts = by_server.get(key_id, 0)
|
||||||
|
by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
|
||||||
|
|
||||||
|
deduped_requests = [
|
||||||
|
_FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
|
||||||
|
for server_name, by_server in server_to_key_to_ts.items()
|
||||||
|
for key_id, minimum_valid_ts in by_server.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug("Deduplicated key requests to %s", deduped_requests)
|
||||||
|
|
||||||
|
# For each key we call `_inner_verify_request` which will handle
|
||||||
|
# fetching each key. Note these shouldn't throw if we fail to contact
|
||||||
|
# other servers etc.
|
||||||
|
results_per_request = await yieldable_gather_results(
|
||||||
|
self._inner_fetch_key_request,
|
||||||
|
deduped_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We now convert the returned list of results into a map from server
|
||||||
|
# name to key ID to FetchKeyResult, to return.
|
||||||
|
to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]]
|
||||||
|
for (request, results) in zip(deduped_requests, results_per_request):
|
||||||
|
to_return_by_server = to_return.setdefault(request.server_name, {})
|
||||||
|
for key_id, key_result in results.items():
|
||||||
|
existing = to_return_by_server.get(key_id)
|
||||||
|
if not existing or existing.valid_until_ts < key_result.valid_until_ts:
|
||||||
|
to_return_by_server[key_id] = key_result
|
||||||
|
|
||||||
|
return to_return
|
||||||
|
|
||||||
|
async def _inner_fetch_key_request(
|
||||||
|
self, verify_request: _FetchKeyRequest
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
"""Attempt to fetch the given key by calling each key fetcher one by
|
||||||
|
one.
|
||||||
|
"""
|
||||||
|
logger.debug("Starting fetch for %s", verify_request)
|
||||||
|
|
||||||
|
found_keys: Dict[str, FetchKeyResult] = {}
|
||||||
|
missing_key_ids = set(verify_request.key_ids)
|
||||||
|
|
||||||
|
for fetcher in self._key_fetchers:
|
||||||
|
if not missing_key_ids:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug("Getting keys from %s for %s", fetcher, verify_request)
|
||||||
|
keys = await fetcher.get_keys(
|
||||||
verify_request.server_name,
|
verify_request.server_name,
|
||||||
verify_request.key_ids,
|
list(missing_key_ids),
|
||||||
verify_request.minimum_valid_until_ts,
|
verify_request.minimum_valid_until_ts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# add the key request to the queue, but don't start it off yet.
|
for key_id, key in keys.items():
|
||||||
key_lookups.append(verify_request)
|
if not key:
|
||||||
|
|
||||||
# now run _handle_key_deferred, which will wait for the key request
|
|
||||||
# to complete and then do the verification.
|
|
||||||
#
|
|
||||||
# We want _handle_key_request to log to the right context, so we
|
|
||||||
# wrap it with preserve_fn (aka run_in_background)
|
|
||||||
return handle(verify_request)
|
|
||||||
|
|
||||||
results = [process(r) for r in verify_requests]
|
|
||||||
|
|
||||||
if key_lookups:
|
|
||||||
run_in_background(self._start_key_lookups, key_lookups)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def _start_key_lookups(
|
|
||||||
self, verify_requests: List[VerifyJsonRequest]
|
|
||||||
) -> None:
|
|
||||||
"""Sets off the key fetches for each verify request
|
|
||||||
|
|
||||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_requests:
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# map from server name to a set of outstanding request ids
|
|
||||||
server_to_request_ids = {} # type: Dict[str, Set[int]]
|
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
request_id = id(verify_request)
|
|
||||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
|
||||||
|
|
||||||
# Wait for any previous lookups to complete before proceeding.
|
|
||||||
await self.wait_for_previous_lookups(server_to_request_ids.keys())
|
|
||||||
|
|
||||||
# take out a lock on each of the servers by sticking a Deferred in
|
|
||||||
# key_downloads
|
|
||||||
for server_name in server_to_request_ids.keys():
|
|
||||||
self.key_downloads[server_name] = defer.Deferred()
|
|
||||||
logger.debug("Got key lookup lock on %s", server_name)
|
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
|
||||||
# drop the lock by resolving the deferred in key_downloads.
|
|
||||||
def drop_server_lock(server_name):
|
|
||||||
d = self.key_downloads.pop(server_name)
|
|
||||||
d.callback(None)
|
|
||||||
|
|
||||||
def lookup_done(res, verify_request):
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
server_requests = server_to_request_ids[server_name]
|
|
||||||
server_requests.remove(id(verify_request))
|
|
||||||
|
|
||||||
# if there are no more requests for this server, we can drop the lock.
|
|
||||||
if not server_requests:
|
|
||||||
logger.debug("Releasing key lookup lock on %s", server_name)
|
|
||||||
drop_server_lock(server_name)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
verify_request.key_ready.addBoth(lookup_done, verify_request)
|
|
||||||
|
|
||||||
# Actually start fetching keys.
|
|
||||||
self._get_server_verify_keys(verify_requests)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error starting key lookups")
|
|
||||||
|
|
||||||
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
|
|
||||||
"""Waits for any previous key lookups for the given servers to finish.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_names: list of servers which we want to look up
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolves once all key lookups for the given servers have
|
|
||||||
completed. Follows the synapse rules of logcontext preservation.
|
|
||||||
"""
|
|
||||||
loop_count = 1
|
|
||||||
while True:
|
|
||||||
wait_on = [
|
|
||||||
(server_name, self.key_downloads[server_name])
|
|
||||||
for server_name in server_names
|
|
||||||
if server_name in self.key_downloads
|
|
||||||
]
|
|
||||||
if not wait_on:
|
|
||||||
break
|
|
||||||
logger.info(
|
|
||||||
"Waiting for existing lookups for %s to complete [loop %i]",
|
|
||||||
[w[0] for w in wait_on],
|
|
||||||
loop_count,
|
|
||||||
)
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
await defer.DeferredList((w[1] for w in wait_on))
|
|
||||||
|
|
||||||
loop_count += 1
|
|
||||||
|
|
||||||
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
|
|
||||||
"""Tries to find at least one key for each verify request
|
|
||||||
|
|
||||||
For each verify_request, verify_request.key_ready is called back with
|
|
||||||
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
|
|
||||||
with a SynapseError if none of the keys are found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_requests: list of verify requests
|
|
||||||
"""
|
|
||||||
|
|
||||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
|
||||||
|
|
||||||
async def do_iterations():
|
|
||||||
try:
|
|
||||||
with Measure(self.clock, "get_server_verify_keys"):
|
|
||||||
for f in self._key_fetchers:
|
|
||||||
if not remaining_requests:
|
|
||||||
return
|
|
||||||
await self._attempt_key_fetches_with_fetcher(
|
|
||||||
f, remaining_requests
|
|
||||||
)
|
|
||||||
|
|
||||||
# look for any requests which weren't satisfied
|
|
||||||
while remaining_requests:
|
|
||||||
verify_request = remaining_requests.pop()
|
|
||||||
rq_str = (
|
|
||||||
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
|
|
||||||
% (
|
|
||||||
verify_request.server_name,
|
|
||||||
verify_request.key_ids,
|
|
||||||
verify_request.minimum_valid_until_ts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we run the errback immediately, it may cancel our
|
|
||||||
# loggingcontext while we are still in it, so instead we
|
|
||||||
# schedule it for the next time round the reactor.
|
|
||||||
#
|
|
||||||
# (this also ensures that we don't get a stack overflow if we
|
|
||||||
# has a massive queue of lookups waiting for this server).
|
|
||||||
self.clock.call_later(
|
|
||||||
0,
|
|
||||||
verify_request.key_ready.errback,
|
|
||||||
SynapseError(
|
|
||||||
401,
|
|
||||||
"Failed to find any key to satisfy %s" % (rq_str,),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as err:
|
|
||||||
# we don't really expect to get here, because any errors should already
|
|
||||||
# have been caught and logged. But if we do, let's log the error and make
|
|
||||||
# sure that all of the deferreds are resolved.
|
|
||||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
if not verify_request.key_ready.called:
|
|
||||||
verify_request.key_ready.errback(err)
|
|
||||||
|
|
||||||
run_in_background(do_iterations)
|
|
||||||
|
|
||||||
async def _attempt_key_fetches_with_fetcher(
|
|
||||||
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
|
|
||||||
):
|
|
||||||
"""Use a key fetcher to attempt to satisfy some key requests
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fetcher: fetcher to use to fetch the keys
|
|
||||||
remaining_requests: outstanding key requests.
|
|
||||||
Any successfully-completed requests will be removed from the list.
|
|
||||||
"""
|
|
||||||
# The keys to fetch.
|
|
||||||
# server_name -> key_id -> min_valid_ts
|
|
||||||
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
|
|
||||||
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
# any completed requests should already have been removed
|
|
||||||
assert not verify_request.key_ready.called
|
|
||||||
keys_for_server = missing_keys[verify_request.server_name]
|
|
||||||
|
|
||||||
for key_id in verify_request.key_ids:
|
|
||||||
# If we have several requests for the same key, then we only need to
|
|
||||||
# request that key once, but we should do so with the greatest
|
|
||||||
# min_valid_until_ts of the requests, so that we can satisfy all of
|
|
||||||
# the requests.
|
|
||||||
keys_for_server[key_id] = max(
|
|
||||||
keys_for_server.get(key_id, -1),
|
|
||||||
verify_request.minimum_valid_until_ts,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = await fetcher.get_keys(missing_keys)
|
|
||||||
|
|
||||||
completed = []
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
|
|
||||||
# see if any of the keys we got this time are sufficient to
|
|
||||||
# complete this VerifyJsonRequest.
|
|
||||||
result_keys = results.get(server_name, {})
|
|
||||||
for key_id in verify_request.key_ids:
|
|
||||||
fetch_key_result = result_keys.get(key_id)
|
|
||||||
if not fetch_key_result:
|
|
||||||
# we didn't get a result for this key
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
# If we already have a result for the given key ID we keep the
|
||||||
fetch_key_result.valid_until_ts
|
# one with the highest `valid_until_ts`.
|
||||||
< verify_request.minimum_valid_until_ts
|
existing_key = found_keys.get(key_id)
|
||||||
):
|
if existing_key:
|
||||||
# key was not valid at this point
|
if key.valid_until_ts <= existing_key.valid_until_ts:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# we have a valid key for this request. If we run the callback
|
# We always store the returned key even if it doesn't the
|
||||||
# immediately, it may cancel our loggingcontext while we are still in
|
# `minimum_valid_until_ts` requirement, as some verification
|
||||||
# it, so instead we schedule it for the next time round the reactor.
|
# requests may still be able to be satisfied by it.
|
||||||
#
|
#
|
||||||
# (this also ensures that we don't get a stack overflow if we had
|
# We still keep looking for the key from other fetchers in that
|
||||||
# a massive queue of lookups waiting for this server).
|
# case though.
|
||||||
logger.debug(
|
found_keys[key_id] = key
|
||||||
"Found key %s:%s for %s",
|
|
||||||
server_name,
|
|
||||||
key_id,
|
|
||||||
verify_request.request_name,
|
|
||||||
)
|
|
||||||
self.clock.call_later(
|
|
||||||
0,
|
|
||||||
verify_request.key_ready.callback,
|
|
||||||
(server_name, key_id, fetch_key_result.verify_key),
|
|
||||||
)
|
|
||||||
completed.append(verify_request)
|
|
||||||
break
|
|
||||||
|
|
||||||
remaining_requests.difference_update(completed)
|
if key.valid_until_ts < verify_request.minimum_valid_until_ts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
missing_key_ids.discard(key_id)
|
||||||
|
|
||||||
|
return found_keys
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher(metaclass=abc.ABCMeta):
|
class KeyFetcher(metaclass=abc.ABCMeta):
|
||||||
@abc.abstractmethod
|
def __init__(self, hs: "HomeServer"):
|
||||||
async def get_keys(
|
self._queue = BatchingQueue(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self.__class__.__name__, hs.get_clock(), self._fetch_keys
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
)
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
keys_to_fetch:
|
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
|
||||||
|
|
||||||
Returns:
|
async def get_keys(
|
||||||
Map from server_name -> key_id -> FetchKeyResult
|
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
"""
|
) -> Dict[str, FetchKeyResult]:
|
||||||
raise NotImplementedError
|
results = await self._queue.add_to_queue(
|
||||||
|
_FetchKeyRequest(
|
||||||
|
server_name=server_name,
|
||||||
|
key_ids=key_ids,
|
||||||
|
minimum_valid_until_ts=minimum_valid_until_ts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results.get(server_name, {})
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _fetch_keys(
|
||||||
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StoreKeyFetcher(KeyFetcher):
|
class StoreKeyFetcher(KeyFetcher):
|
||||||
"""KeyFetcher impl which fetches keys from our data store"""
|
"""KeyFetcher impl which fetches keys from our data store"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def get_keys(
|
async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
|
||||||
"""see KeyFetcher.get_keys"""
|
|
||||||
|
|
||||||
key_ids_to_fetch = (
|
key_ids_to_fetch = (
|
||||||
(server_name, key_id)
|
(queue_value.server_name, key_id)
|
||||||
for server_name, keys_for_server in keys_to_fetch.items()
|
for queue_value in keys_to_fetch
|
||||||
for key_id in keys_for_server.keys()
|
for key_id in queue_value.key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
||||||
@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
|
|||||||
|
|
||||||
class BaseV2KeyFetcher(KeyFetcher):
|
class BaseV2KeyFetcher(KeyFetcher):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
self.key_servers = self.config.key_servers
|
self.key_servers = self.config.key_servers
|
||||||
|
|
||||||
async def get_keys(
|
async def _fetch_keys(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher._fetch_keys"""
|
||||||
|
|
||||||
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
||||||
try:
|
try:
|
||||||
@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
return union_of_keys
|
return union_of_keys
|
||||||
|
|
||||||
async def get_server_verify_key_v2_indirect(
|
async def get_server_verify_key_v2_indirect(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
|
self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch:
|
keys_to_fetch:
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
the keys to be fetched.
|
||||||
|
|
||||||
key_server: notary server to query for the keys
|
key_server: notary server to query for the keys
|
||||||
|
|
||||||
@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
perspective_name = key_server.server_name
|
perspective_name = key_server.server_name
|
||||||
logger.info(
|
logger.info(
|
||||||
"Requesting keys %s from notary server %s",
|
"Requesting keys %s from notary server %s",
|
||||||
keys_to_fetch.items(),
|
keys_to_fetch,
|
||||||
perspective_name,
|
perspective_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
path="/_matrix/key/v2/query",
|
path="/_matrix/key/v2/query",
|
||||||
data={
|
data={
|
||||||
"server_keys": {
|
"server_keys": {
|
||||||
server_name: {
|
queue_value.server_name: {
|
||||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
key_id: {
|
||||||
for key_id, min_valid_ts in server_keys.items()
|
"minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
|
||||||
|
}
|
||||||
|
for key_id in queue_value.key_ids
|
||||||
}
|
}
|
||||||
for server_name, server_keys in keys_to_fetch.items()
|
for queue_value in keys_to_fetch
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
|
|
||||||
async def get_keys(
|
async def get_keys(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
results = await self._queue.add_to_queue(
|
||||||
|
_FetchKeyRequest(
|
||||||
|
server_name=server_name,
|
||||||
|
key_ids=key_ids,
|
||||||
|
minimum_valid_until_ts=minimum_valid_until_ts,
|
||||||
|
),
|
||||||
|
key=server_name,
|
||||||
|
)
|
||||||
|
return results.get(server_name, {})
|
||||||
|
|
||||||
|
async def _fetch_keys(
|
||||||
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
|
async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
|
||||||
server_name, key_ids = key_to_fetch_item
|
server_name = key_to_fetch_item.server_name
|
||||||
|
key_ids = key_to_fetch_item.key_ids
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||||
results[server_name] = keys
|
results[server_name] = keys
|
||||||
@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||||
|
|
||||||
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
await yieldable_gather_results(get_key, keys_to_fetch)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_server_verify_key_v2_direct(
|
async def get_server_verify_key_v2_direct(
|
||||||
@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
|
|
||||||
"""Waits for the key to become available, and then performs a verification
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_request:
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SynapseError if there was a problem performing the verification
|
|
||||||
"""
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
_, key_id, verify_key = await verify_request.key_ready
|
|
||||||
|
|
||||||
json_object = verify_request.get_json_object()
|
|
||||||
|
|
||||||
try:
|
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
|
||||||
except SignatureVerifyException as e:
|
|
||||||
logger.debug(
|
|
||||||
"Error verifying signature for %s:%s:%s with key %s: %s",
|
|
||||||
server_name,
|
|
||||||
verify_key.alg,
|
|
||||||
verify_key.version,
|
|
||||||
encode_verify_key_base64(verify_key),
|
|
||||||
str(e),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"Invalid signature for server %s with key %s:%s: %s"
|
|
||||||
% (server_name, verify_key.alg, verify_key.version, str(e)),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
@ -152,7 +152,9 @@ class Authenticator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.keyring.verify_json_for_server(
|
await self.keyring.verify_json_for_server(
|
||||||
origin, json_request, now, "Incoming request"
|
origin,
|
||||||
|
json_request,
|
||||||
|
now,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Request from %s", origin)
|
logger.debug("Request from %s", origin)
|
||||||
|
@ -108,7 +108,9 @@ class GroupAttestationSigning:
|
|||||||
|
|
||||||
assert server_name is not None
|
assert server_name is not None
|
||||||
await self.keyring.verify_json_for_server(
|
await self.keyring.verify_json_for_server(
|
||||||
server_name, attestation, now, "Group attestation"
|
server_name,
|
||||||
|
attestation,
|
||||||
|
now,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
|
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
|
||||||
|
@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
|
|||||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource):
|
|||||||
# If there is a cache miss, request the missing keys, then recurse (and
|
# If there is a cache miss, request the missing keys, then recurse (and
|
||||||
# ensure the result is sent).
|
# ensure the result is sent).
|
||||||
if cache_misses and query_remote_on_cache_miss:
|
if cache_misses and query_remote_on_cache_miss:
|
||||||
await self.fetcher.get_keys(cache_misses)
|
await yieldable_gather_results(
|
||||||
|
lambda t: self.fetcher.get_keys(*t),
|
||||||
|
(
|
||||||
|
(server_name, list(keys), 0)
|
||||||
|
for server_name, keys in cache_misses.items()
|
||||||
|
),
|
||||||
|
)
|
||||||
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||||
else:
|
else:
|
||||||
signed_keys = []
|
signed_keys = []
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
|
from typing import Dict, List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
@ -21,7 +22,6 @@ import signedjson.sign
|
|||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey
|
||||||
from signedjson.key import encode_verify_key_base64, get_verify_key
|
from signedjson.key import encode_verify_key_base64, get_verify_key
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
# deferred completes.
|
# deferred completes.
|
||||||
first_lookup_deferred = Deferred()
|
first_lookup_deferred = Deferred()
|
||||||
|
|
||||||
async def first_lookup_fetch(keys_to_fetch):
|
async def first_lookup_fetch(
|
||||||
self.assertEquals(current_context().request.id, "context_11")
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
# self.assertEquals(current_context().request.id, "context_11")
|
||||||
|
self.assertEqual(server_name, "server10")
|
||||||
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 0)
|
||||||
|
|
||||||
await make_deferred_yieldable(first_lookup_deferred)
|
await make_deferred_yieldable(first_lookup_deferred)
|
||||||
return {
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||||
"server10": {
|
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
||||||
|
|
||||||
async def first_lookup():
|
async def first_lookup():
|
||||||
with LoggingContext("context_11", request=FakeRequest("context_11")):
|
with LoggingContext("context_11", request=FakeRequest("context_11")):
|
||||||
res_deferreds = kr.verify_json_objects_for_server(
|
res_deferreds = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
|
[("server10", json1, 0), ("server11", {}, 0)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# the unsigned json should be rejected pretty quickly
|
# the unsigned json should be rejected pretty quickly
|
||||||
@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
d0 = ensureDeferred(first_lookup())
|
d0 = ensureDeferred(first_lookup())
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
mock_fetcher.get_keys.assert_called_once()
|
mock_fetcher.get_keys.assert_called_once()
|
||||||
|
|
||||||
# a second request for a server with outstanding requests
|
# a second request for a server with outstanding requests
|
||||||
# should block rather than start a second call
|
# should block rather than start a second call
|
||||||
|
|
||||||
async def second_lookup_fetch(keys_to_fetch):
|
async def second_lookup_fetch(
|
||||||
self.assertEquals(current_context().request.id, "context_12")
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server10": {
|
# self.assertEquals(current_context().request.id, "context_12")
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher.get_keys.reset_mock()
|
mock_fetcher.get_keys.reset_mock()
|
||||||
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
||||||
@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
async def second_lookup():
|
async def second_lookup():
|
||||||
with LoggingContext("context_12", request=FakeRequest("context_12")):
|
with LoggingContext("context_12", request=FakeRequest("context_12")):
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1, 0, "test")]
|
[
|
||||||
|
(
|
||||||
|
"server10",
|
||||||
|
json1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
second_lookup_state[0] = 1
|
second_lookup_state[0] = 1
|
||||||
@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
|
||||||
# should fail immediately on an unsigned object
|
# should fail immediately on an unsigned object
|
||||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
d = kr.verify_json_for_server("server9", {}, 0)
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# should succeed on a signed object
|
# should succeed on a signed object
|
||||||
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
|
d = kr.verify_json_for_server("server9", json1, 500)
|
||||||
# self.assertFalse(d.called)
|
# self.assertFalse(d.called)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
|
||||||
# should fail immediately on an unsigned object
|
# should fail immediately on an unsigned object
|
||||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
d = kr.verify_json_for_server("server9", {}, 0)
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
||||||
# as it tries to refetch the keys and fails.
|
# as it tries to refetch the keys and fails.
|
||||||
d = _verify_json_for_server(
|
d = kr.verify_json_for_server("server9", json1, 500)
|
||||||
kr, "server9", json1, 500, "test signed non-zero min"
|
|
||||||
)
|
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# We expect the keyring tried to refetch the key once.
|
# We expect the keyring tried to refetch the key once.
|
||||||
mock_fetcher.get_keys.assert_called_once_with(
|
mock_fetcher.get_keys.assert_called_once_with(
|
||||||
{"server9": {get_key_id(key1): 500}}
|
"server9", [get_key_id(key1)], 500
|
||||||
)
|
)
|
||||||
|
|
||||||
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
||||||
d = _verify_json_for_server(
|
d = kr.verify_json_for_server(
|
||||||
kr, "server9", json1, 0, "test signed with zero min"
|
"server9",
|
||||||
|
json1,
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
"""Two requests for the same key should be deduped."""
|
"""Two requests for the same key should be deduped."""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
async def get_keys(keys_to_fetch):
|
async def get_keys(
|
||||||
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
# there should only be one request object (with the max validity)
|
# there should only be one request object (with the max validity)
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
self.assertEqual(server_name, "server1")
|
||||||
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
|
|
||||||
return {
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||||
"server1": {
|
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher = Mock()
|
mock_fetcher = Mock()
|
||||||
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||||
@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
# the first request should succeed; the second should fail because the key
|
# the first request should succeed; the second should fail because the key
|
||||||
# has expired
|
# has expired
|
||||||
results = kr.verify_json_objects_for_server(
|
results = kr.verify_json_objects_for_server(
|
||||||
[("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
|
[
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
500,
|
||||||
|
),
|
||||||
|
("server1", json1, 1500),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.get_success(results[0])
|
self.get_success(results[0])
|
||||||
@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
async def get_keys1(keys_to_fetch):
|
async def get_keys1(
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
self.assertEqual(server_name, "server1")
|
||||||
}
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
||||||
|
|
||||||
async def get_keys2(keys_to_fetch):
|
async def get_keys2(
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server1": {
|
self.assertEqual(server_name, "server1")
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
}
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
}
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||||
|
|
||||||
mock_fetcher1 = Mock()
|
mock_fetcher1 = Mock()
|
||||||
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||||
@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
signedjson.sign.sign_json(json1, "server1", key1)
|
signedjson.sign.sign_json(json1, "server1", key1)
|
||||||
|
|
||||||
results = kr.verify_json_objects_for_server(
|
results = kr.verify_json_objects_for_server(
|
||||||
[("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
|
[
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
1200,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
1500,
|
||||||
|
),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.get_success(results[0])
|
self.get_success(results[0])
|
||||||
@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.http_client.get_json.side_effect = get_json
|
self.http_client.get_json.side_effect = get_json
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
# change the server name: the result should be ignored
|
# change the server name: the result should be ignored
|
||||||
response["server_name"] = "OTHER_SERVER"
|
response["server_name"] = "OTHER_SERVER"
|
||||||
|
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
self.assertEqual(keys, {})
|
self.assertEqual(keys, {})
|
||||||
|
|
||||||
|
|
||||||
@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
self.assertIn(testverifykey_id, keys)
|
||||||
self.assertIn(SERVER_NAME, keys)
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
self.assertIn(testverifykey_id, keys)
|
||||||
self.assertIn(SERVER_NAME, keys)
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
def get_key_from_perspectives(response):
|
def get_key_from_perspectives(response):
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs)
|
fetcher = PerspectivesKeyFetcher(self.hs)
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
return self.get_success(fetcher.get_keys(keys_to_fetch))
|
return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
|
|
||||||
# start with a valid response so we can check we are testing the right thing
|
# start with a valid response so we can check we are testing the right thing
|
||||||
response = build_response()
|
response = build_response()
|
||||||
keys = get_key_from_perspectives(response)
|
keys = get_key_from_perspectives(response)
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
k = keys[testverifykey_id]
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
|
|
||||||
# remove the perspectives server's signature
|
# remove the perspectives server's signature
|
||||||
@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||||||
def get_key_id(key):
|
def get_key_id(key):
|
||||||
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
||||||
return "%s:%s" % (key.alg, key.version)
|
return "%s:%s" % (key.alg, key.version)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def run_in_context(f, *args, **kwargs):
|
|
||||||
with LoggingContext("testctx"):
|
|
||||||
rv = yield f(*args, **kwargs)
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_json_for_server(kr, *args):
|
|
||||||
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
|
||||||
with the patched defer.inlineCallbacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def v():
|
|
||||||
rv1 = yield kr.verify_json_for_server(*args)
|
|
||||||
return rv1
|
|
||||||
|
|
||||||
return run_in_context(v)
|
|
||||||
|
@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
|||||||
keyid = "ed25519:%s" % (testkey.version,)
|
keyid = "ed25519:%s" % (testkey.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({"targetserver": {keyid: 1000}})
|
d = fetcher.get_keys("targetserver", [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn("targetserver", res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res["targetserver"][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
|||||||
keyid = "ed25519:%s" % (testkey.version,)
|
keyid = "ed25519:%s" % (testkey.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
|
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn(self.hs.hostname, res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res[self.hs.hostname][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
|||||||
keyid = "ed25519:%s" % (self.hs_signing_key.version,)
|
keyid = "ed25519:%s" % (self.hs_signing_key.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
|
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn(self.hs.hostname, res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res[self.hs.hostname][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
|
@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase):
|
|||||||
self._pending_calls.append((values, d))
|
self._pending_calls.append((values, d))
|
||||||
return await make_deferred_yieldable(d)
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
def _get_sample_with_name(self, metric, name) -> int:
|
||||||
|
"""For a prometheus metric get the value of the sample that has a
|
||||||
|
matching "name" label.
|
||||||
|
"""
|
||||||
|
for sample in metric.collect()[0].samples:
|
||||||
|
if sample.labels.get("name") == name:
|
||||||
|
return sample.value
|
||||||
|
|
||||||
|
self.fail("Found no matching sample")
|
||||||
|
|
||||||
def _assert_metrics(self, queued, keys, in_flight):
|
def _assert_metrics(self, queued, keys, in_flight):
|
||||||
"""Assert that the metrics are correct"""
|
"""Assert that the metrics are correct"""
|
||||||
|
|
||||||
self.assertEqual(len(number_queued.collect()), 1)
|
sample = self._get_sample_with_name(number_queued, self.queue._name)
|
||||||
self.assertEqual(len(number_queued.collect()[0].samples), 1)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
number_queued.collect()[0].samples[0].labels,
|
sample,
|
||||||
{"name": self.queue._name},
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_queued.collect()[0].samples[0].value,
|
|
||||||
queued,
|
queued,
|
||||||
"number_queued",
|
"number_queued",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(number_of_keys.collect()), 1)
|
sample = self._get_sample_with_name(number_of_keys, self.queue._name)
|
||||||
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
|
self.assertEqual(sample, keys, "number_of_keys")
|
||||||
self.assertEqual(
|
|
||||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(number_in_flight.collect()), 1)
|
sample = self._get_sample_with_name(number_in_flight, self.queue._name)
|
||||||
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
sample,
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_in_flight.collect()[0].samples[0].value,
|
|
||||||
in_flight,
|
in_flight,
|
||||||
"number_in_flight",
|
"number_in_flight",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user