mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-27 21:57:04 -05:00
Add type hints to the crypto module. (#8999)
This commit is contained in:
parent
a685bbb018
commit
1c9a850562
1
changelog.d/8999.misc
Normal file
1
changelog.d/8999.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to the crypto module.
|
2
mypy.ini
2
mypy.ini
@ -17,6 +17,7 @@ files =
|
|||||||
synapse/api,
|
synapse/api,
|
||||||
synapse/appservice,
|
synapse/appservice,
|
||||||
synapse/config,
|
synapse/config,
|
||||||
|
synapse/crypto,
|
||||||
synapse/event_auth.py,
|
synapse/event_auth.py,
|
||||||
synapse/events/builder.py,
|
synapse/events/builder.py,
|
||||||
synapse/events/validator.py,
|
synapse/events/validator.py,
|
||||||
@ -75,6 +76,7 @@ files =
|
|||||||
synapse/storage/background_updates.py,
|
synapse/storage/background_updates.py,
|
||||||
synapse/storage/databases/main/appservice.py,
|
synapse/storage/databases/main/appservice.py,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
|
synapse/storage/databases/main/keys.py,
|
||||||
synapse/storage/databases/main/pusher.py,
|
synapse/storage/databases/main/pusher.py,
|
||||||
synapse/storage/databases/main/registration.py,
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
|
@ -227,7 +227,7 @@ class ConnectionVerifier:
|
|||||||
|
|
||||||
# This code is based on twisted.internet.ssl.ClientTLSOptions.
|
# This code is based on twisted.internet.ssl.ClientTLSOptions.
|
||||||
|
|
||||||
def __init__(self, hostname: bytes, verify_certs):
|
def __init__(self, hostname: bytes, verify_certs: bool):
|
||||||
self._verify_certs = verify_certs
|
self._verify_certs = verify_certs
|
||||||
|
|
||||||
_decoded = hostname.decode("ascii")
|
_decoded = hostname.decode("ascii")
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import collections.abc
|
import collections.abc
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import prune_event, prune_event_dict
|
from synapse.events.utils import prune_event, prune_event_dict
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
Hasher = Callable[[bytes], "hashlib._Hash"]
|
||||||
|
|
||||||
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
|
||||||
|
def check_event_content_hash(
|
||||||
|
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
|
||||||
|
) -> bool:
|
||||||
"""Check whether the hash for this PDU matches the contents"""
|
"""Check whether the hash for this PDU matches the contents"""
|
||||||
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
|
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
|||||||
return message_hash_bytes == expected_hash
|
return message_hash_bytes == expected_hash
|
||||||
|
|
||||||
|
|
||||||
def compute_content_hash(event_dict, hash_algorithm):
|
def compute_content_hash(
|
||||||
|
event_dict: Dict[str, Any], hash_algorithm: Hasher
|
||||||
|
) -> Tuple[str, bytes]:
|
||||||
"""Compute the content hash of an event, which is the hash of the
|
"""Compute the content hash of an event, which is the hash of the
|
||||||
unredacted event.
|
unredacted event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_dict (dict): The unredacted event as a dict
|
event_dict: The unredacted event as a dict
|
||||||
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
|
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
|
||||||
to hash the event
|
to hash the event
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
|
A tuple of the name of hash and the hash as raw bytes.
|
||||||
bytes.
|
|
||||||
"""
|
"""
|
||||||
event_dict = dict(event_dict)
|
event_dict = dict(event_dict)
|
||||||
event_dict.pop("age_ts", None)
|
event_dict.pop("age_ts", None)
|
||||||
@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
|
|||||||
return hashed.name, hashed.digest()
|
return hashed.name, hashed.digest()
|
||||||
|
|
||||||
|
|
||||||
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
|
def compute_event_reference_hash(
|
||||||
|
event, hash_algorithm: Hasher = hashlib.sha256
|
||||||
|
) -> Tuple[str, bytes]:
|
||||||
"""Computes the event reference hash. This is the hash of the redacted
|
"""Computes the event reference hash. This is the hash of the redacted
|
||||||
event.
|
event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (FrozenEvent)
|
event
|
||||||
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
|
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
|
||||||
to hash the event
|
to hash the event
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
|
A tuple of the name of hash and the hash as raw bytes.
|
||||||
bytes.
|
|
||||||
"""
|
"""
|
||||||
tmp_event = prune_event(event)
|
tmp_event = prune_event(event)
|
||||||
event_dict = tmp_event.get_pdu_json()
|
event_dict = tmp_event.get_pdu_json()
|
||||||
@ -156,7 +163,7 @@ def add_hashes_and_signatures(
|
|||||||
event_dict: JsonDict,
|
event_dict: JsonDict,
|
||||||
signature_name: str,
|
signature_name: str,
|
||||||
signing_key: SigningKey,
|
signing_key: SigningKey,
|
||||||
):
|
) -> None:
|
||||||
"""Add content hash and sign the event
|
"""Add content hash and sign the event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -14,9 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import (
|
from signedjson.key import (
|
||||||
@ -40,6 +42,7 @@ from synapse.api.errors import (
|
|||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.config.key import TrustedKeyServer
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
@ -47,11 +50,15 @@ from synapse.logging.context import (
|
|||||||
run_in_background,
|
run_in_background,
|
||||||
)
|
)
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import yieldable_gather_results
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -61,16 +68,17 @@ class VerifyJsonRequest:
|
|||||||
A request to verify a JSON object.
|
A request to verify a JSON object.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
server_name(str): The name of the server to verify against.
|
server_name: The name of the server to verify against.
|
||||||
|
|
||||||
key_ids(set[str]): The set of key_ids to that could be used to verify the
|
json_object: The JSON object to verify.
|
||||||
JSON object
|
|
||||||
|
|
||||||
json_object(dict): The JSON object to verify.
|
minimum_valid_until_ts: time at which we require the signing key to
|
||||||
|
|
||||||
minimum_valid_until_ts (int): time at which we require the signing key to
|
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
|
request_name: The name of the request.
|
||||||
|
|
||||||
|
key_ids: The set of key_ids to that could be used to verify the JSON object
|
||||||
|
|
||||||
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
|
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
|
||||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
a verify key has been fetched. The deferreds' callbacks are run with no
|
||||||
@ -80,12 +88,12 @@ class VerifyJsonRequest:
|
|||||||
errbacks with an M_UNAUTHORIZED SynapseError.
|
errbacks with an M_UNAUTHORIZED SynapseError.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
server_name = attr.ib()
|
server_name = attr.ib(type=str)
|
||||||
json_object = attr.ib()
|
json_object = attr.ib(type=JsonDict)
|
||||||
minimum_valid_until_ts = attr.ib()
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
request_name = attr.ib()
|
request_name = attr.ib(type=str)
|
||||||
key_ids = attr.ib(init=False)
|
key_ids = attr.ib(init=False, type=List[str])
|
||||||
key_ready = attr.ib(default=attr.Factory(defer.Deferred))
|
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
self.key_ids = signature_ids(self.json_object, self.server_name)
|
self.key_ids = signature_ids(self.json_object, self.server_name)
|
||||||
@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
|
|||||||
|
|
||||||
|
|
||||||
class Keyring:
|
class Keyring:
|
||||||
def __init__(self, hs, key_fetchers=None):
|
def __init__(
|
||||||
|
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
||||||
|
):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
if key_fetchers is None:
|
if key_fetchers is None:
|
||||||
@ -112,22 +122,26 @@ class Keyring:
|
|||||||
# completes.
|
# completes.
|
||||||
#
|
#
|
||||||
# These are regular, logcontext-agnostic Deferreds.
|
# These are regular, logcontext-agnostic Deferreds.
|
||||||
self.key_downloads = {}
|
self.key_downloads = {} # type: Dict[str, defer.Deferred]
|
||||||
|
|
||||||
def verify_json_for_server(
|
def verify_json_for_server(
|
||||||
self, server_name, json_object, validity_time, request_name
|
self,
|
||||||
):
|
server_name: str,
|
||||||
|
json_object: JsonDict,
|
||||||
|
validity_time: int,
|
||||||
|
request_name: str,
|
||||||
|
) -> defer.Deferred:
|
||||||
"""Verify that a JSON object has been signed by a given server
|
"""Verify that a JSON object has been signed by a given server
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str): name of the server which must have signed this object
|
server_name: name of the server which must have signed this object
|
||||||
|
|
||||||
json_object (dict): object to be checked
|
json_object: object to be checked
|
||||||
|
|
||||||
validity_time (int): timestamp at which we require the signing key to
|
validity_time: timestamp at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
request_name (str): an identifier for this json object (eg, an event id)
|
request_name: an identifier for this json object (eg, an event id)
|
||||||
for logging.
|
for logging.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -138,12 +152,14 @@ class Keyring:
|
|||||||
requests = (req,)
|
requests = (req,)
|
||||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
||||||
|
|
||||||
def verify_json_objects_for_server(self, server_and_json):
|
def verify_json_objects_for_server(
|
||||||
|
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
|
||||||
|
) -> List[defer.Deferred]:
|
||||||
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
||||||
necessary.
|
necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_and_json (iterable[Tuple[str, dict, int, str]):
|
server_and_json:
|
||||||
Iterable of (server_name, json_object, validity_time, request_name)
|
Iterable of (server_name, json_object, validity_time, request_name)
|
||||||
tuples.
|
tuples.
|
||||||
|
|
||||||
@ -164,13 +180,14 @@ class Keyring:
|
|||||||
for server_name, json_object, validity_time, request_name in server_and_json
|
for server_name, json_object, validity_time, request_name in server_and_json
|
||||||
)
|
)
|
||||||
|
|
||||||
def _verify_objects(self, verify_requests):
|
def _verify_objects(
|
||||||
|
self, verify_requests: Iterable[VerifyJsonRequest]
|
||||||
|
) -> List[defer.Deferred]:
|
||||||
"""Does the work of verify_json_[objects_]for_server
|
"""Does the work of verify_json_[objects_]for_server
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verify_requests (iterable[VerifyJsonRequest]):
|
verify_requests: Iterable of verification requests.
|
||||||
Iterable of verification requests.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List<Deferred[None]>: for each input item, a deferred indicating success
|
List<Deferred[None]>: for each input item, a deferred indicating success
|
||||||
@ -182,7 +199,7 @@ class Keyring:
|
|||||||
key_lookups = []
|
key_lookups = []
|
||||||
handle = preserve_fn(_handle_key_deferred)
|
handle = preserve_fn(_handle_key_deferred)
|
||||||
|
|
||||||
def process(verify_request):
|
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
|
||||||
"""Process an entry in the request list
|
"""Process an entry in the request list
|
||||||
|
|
||||||
Adds a key request to key_lookups, and returns a deferred which
|
Adds a key request to key_lookups, and returns a deferred which
|
||||||
@ -222,18 +239,20 @@ class Keyring:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _start_key_lookups(self, verify_requests):
|
async def _start_key_lookups(
|
||||||
|
self, verify_requests: List[VerifyJsonRequest]
|
||||||
|
) -> None:
|
||||||
"""Sets off the key fetches for each verify request
|
"""Sets off the key fetches for each verify request
|
||||||
|
|
||||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
Once each fetch completes, verify_request.key_ready will be resolved.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verify_requests (List[VerifyJsonRequest]):
|
verify_requests:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# map from server name to a set of outstanding request ids
|
# map from server name to a set of outstanding request ids
|
||||||
server_to_request_ids = {}
|
server_to_request_ids = {} # type: Dict[str, Set[int]]
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
for verify_request in verify_requests:
|
||||||
server_name = verify_request.server_name
|
server_name = verify_request.server_name
|
||||||
@ -275,11 +294,11 @@ class Keyring:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error starting key lookups")
|
logger.exception("Error starting key lookups")
|
||||||
|
|
||||||
async def wait_for_previous_lookups(self, server_names) -> None:
|
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
|
||||||
"""Waits for any previous key lookups for the given servers to finish.
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_names (Iterable[str]): list of servers which we want to look up
|
server_names: list of servers which we want to look up
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once all key lookups for the given servers have
|
Resolves once all key lookups for the given servers have
|
||||||
@ -304,7 +323,7 @@ class Keyring:
|
|||||||
|
|
||||||
loop_count += 1
|
loop_count += 1
|
||||||
|
|
||||||
def _get_server_verify_keys(self, verify_requests):
|
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
|
||||||
"""Tries to find at least one key for each verify request
|
"""Tries to find at least one key for each verify request
|
||||||
|
|
||||||
For each verify_request, verify_request.key_ready is called back with
|
For each verify_request, verify_request.key_ready is called back with
|
||||||
@ -312,7 +331,7 @@ class Keyring:
|
|||||||
with a SynapseError if none of the keys are found.
|
with a SynapseError if none of the keys are found.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verify_requests (list[VerifyJsonRequest]): list of verify requests
|
verify_requests: list of verify requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
||||||
@ -366,17 +385,19 @@ class Keyring:
|
|||||||
|
|
||||||
run_in_background(do_iterations)
|
run_in_background(do_iterations)
|
||||||
|
|
||||||
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
async def _attempt_key_fetches_with_fetcher(
|
||||||
|
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
|
||||||
|
):
|
||||||
"""Use a key fetcher to attempt to satisfy some key requests
|
"""Use a key fetcher to attempt to satisfy some key requests
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fetcher (KeyFetcher): fetcher to use to fetch the keys
|
fetcher: fetcher to use to fetch the keys
|
||||||
remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
|
remaining_requests: outstanding key requests.
|
||||||
Any successfully-completed requests will be removed from the list.
|
Any successfully-completed requests will be removed from the list.
|
||||||
"""
|
"""
|
||||||
# dict[str, dict[str, int]]: keys to fetch.
|
# The keys to fetch.
|
||||||
# server_name -> key_id -> min_valid_ts
|
# server_name -> key_id -> min_valid_ts
|
||||||
missing_keys = defaultdict(dict)
|
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
|
||||||
|
|
||||||
for verify_request in remaining_requests:
|
for verify_request in remaining_requests:
|
||||||
# any completed requests should already have been removed
|
# any completed requests should already have been removed
|
||||||
@ -438,16 +459,18 @@ class Keyring:
|
|||||||
remaining_requests.difference_update(completed)
|
remaining_requests.difference_update(completed)
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher:
|
class KeyFetcher(metaclass=abc.ABCMeta):
|
||||||
async def get_keys(self, keys_to_fetch):
|
@abc.abstractmethod
|
||||||
|
async def get_keys(
|
||||||
|
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, dict[str, int]]):
|
keys_to_fetch:
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
Map from server_name -> key_id -> FetchKeyResult
|
||||||
map from server_name -> key_id -> FetchKeyResult
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -455,31 +478,35 @@ class KeyFetcher:
|
|||||||
class StoreKeyFetcher(KeyFetcher):
|
class StoreKeyFetcher(KeyFetcher):
|
||||||
"""KeyFetcher impl which fetches keys from our data store"""
|
"""KeyFetcher impl which fetches keys from our data store"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def get_keys(self, keys_to_fetch):
|
async def get_keys(
|
||||||
|
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
keys_to_fetch = (
|
key_ids_to_fetch = (
|
||||||
(server_name, key_id)
|
(server_name, key_id)
|
||||||
for server_name, keys_for_server in keys_to_fetch.items()
|
for server_name, keys_for_server in keys_to_fetch.items()
|
||||||
for key_id in keys_for_server.keys()
|
for key_id in keys_for_server.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
res = await self.store.get_server_verify_keys(keys_to_fetch)
|
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
||||||
keys = {}
|
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
|
||||||
for (server_name, key_id), key in res.items():
|
for (server_name, key_id), key in res.items():
|
||||||
keys.setdefault(server_name, {})[key_id] = key
|
keys.setdefault(server_name, {})[key_id] = key
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
class BaseV2KeyFetcher:
|
class BaseV2KeyFetcher(KeyFetcher):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.config = hs.get_config()
|
self.config = hs.get_config()
|
||||||
|
|
||||||
async def process_v2_response(self, from_server, response_json, time_added_ms):
|
async def process_v2_response(
|
||||||
|
self, from_server: str, response_json: JsonDict, time_added_ms: int
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"""Parse a 'Server Keys' structure from the result of a /key request
|
"""Parse a 'Server Keys' structure from the result of a /key request
|
||||||
|
|
||||||
This is used to parse either the entirety of the response from
|
This is used to parse either the entirety of the response from
|
||||||
@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
|
|||||||
to /_matrix/key/v2/query.
|
to /_matrix/key/v2/query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
from_server (str): the name of the server producing this result: either
|
from_server: the name of the server producing this result: either
|
||||||
the origin server for a /_matrix/key/v2/server request, or the notary
|
the origin server for a /_matrix/key/v2/server request, or the notary
|
||||||
for a /_matrix/key/v2/query.
|
for a /_matrix/key/v2/query.
|
||||||
|
|
||||||
response_json (dict): the json-decoded Server Keys response object
|
response_json: the json-decoded Server Keys response object
|
||||||
|
|
||||||
time_added_ms (int): the timestamp to record in server_keys_json
|
time_added_ms: the timestamp to record in server_keys_json
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
|
Map from key_id to result object
|
||||||
"""
|
"""
|
||||||
ts_valid_until_ms = response_json["valid_until_ts"]
|
ts_valid_until_ms = response_json["valid_until_ts"]
|
||||||
|
|
||||||
@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
|
|||||||
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
|
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
self.key_servers = self.config.key_servers
|
self.key_servers = self.config.key_servers
|
||||||
|
|
||||||
async def get_keys(self, keys_to_fetch):
|
async def get_keys(
|
||||||
|
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
async def get_key(key_server):
|
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
||||||
try:
|
try:
|
||||||
result = await self.get_server_verify_key_v2_indirect(
|
return await self.get_server_verify_key_v2_indirect(
|
||||||
keys_to_fetch, key_server
|
keys_to_fetch, key_server
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
except KeyLookupError as e:
|
except KeyLookupError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Key lookup failed from %r: %s", key_server.server_name, e
|
"Key lookup failed from %r: %s", key_server.server_name, e
|
||||||
@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
)
|
)
|
||||||
|
|
||||||
union_of_keys = {}
|
union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
|
||||||
for result in results:
|
for result in results:
|
||||||
for server_name, keys in result.items():
|
for server_name, keys in result.items():
|
||||||
union_of_keys.setdefault(server_name, {}).update(keys)
|
union_of_keys.setdefault(server_name, {}).update(keys)
|
||||||
|
|
||||||
return union_of_keys
|
return union_of_keys
|
||||||
|
|
||||||
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
async def get_server_verify_key_v2_indirect(
|
||||||
|
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, dict[str, int]]):
|
keys_to_fetch:
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||||
|
|
||||||
key_server (synapse.config.key.TrustedKeyServer): notary server to query for
|
key_server: notary server to query for the keys
|
||||||
the keys
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
|
Map from server_name -> key_id -> FetchKeyResult
|
||||||
from server_name -> key_id -> FetchKeyResult
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyLookupError if there was an error processing the entire response from
|
KeyLookupError if there was an error processing the entire response from
|
||||||
@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
raise KeyLookupError("Remote server returned an error: %s" % (e,))
|
raise KeyLookupError("Remote server returned an error: %s" % (e,))
|
||||||
|
|
||||||
keys = {}
|
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
|
||||||
added_keys = []
|
added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
|
||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
assert isinstance(query_response, dict)
|
||||||
for response in query_response["server_keys"]:
|
for response in query_response["server_keys"]:
|
||||||
# do this first, so that we can give useful errors thereafter
|
# do this first, so that we can give useful errors thereafter
|
||||||
server_name = response.get("server_name")
|
server_name = response.get("server_name")
|
||||||
@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
def _validate_perspectives_response(self, key_server, response):
|
def _validate_perspectives_response(
|
||||||
|
self, key_server: TrustedKeyServer, response: JsonDict
|
||||||
|
) -> None:
|
||||||
"""Optionally check the signature on the result of a /key/query request
|
"""Optionally check the signature on the result of a /key/query request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key_server (synapse.config.key.TrustedKeyServer): the notary server that
|
key_server: the notary server that produced this result
|
||||||
produced this result
|
|
||||||
|
|
||||||
response (dict): the json-decoded Server Keys response object
|
response: the json-decoded Server Keys response object
|
||||||
"""
|
"""
|
||||||
perspective_name = key_server.server_name
|
perspective_name = key_server.server_name
|
||||||
perspective_keys = key_server.verify_keys
|
perspective_keys = key_server.verify_keys
|
||||||
@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||||||
class ServerKeyFetcher(BaseV2KeyFetcher):
|
class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
"""KeyFetcher impl which fetches keys from the origin servers"""
|
"""KeyFetcher impl which fetches keys from the origin servers"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
|
|
||||||
async def get_keys(self, keys_to_fetch):
|
async def get_keys(
|
||||||
|
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch (dict[str, iterable[str]]):
|
keys_to_fetch:
|
||||||
the keys to be fetched. server_name -> key_ids
|
the keys to be fetched. server_name -> key_ids
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
|
Map from server_name -> key_id -> FetchKeyResult
|
||||||
map from server_name -> key_id -> FetchKeyResult
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
async def get_key(key_to_fetch_item):
|
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
|
||||||
server_name, key_ids = key_to_fetch_item
|
server_name, key_ids = key_to_fetch_item
|
||||||
try:
|
try:
|
||||||
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||||
@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
async def get_server_verify_key_v2_direct(
|
||||||
|
self, server_name: str, key_ids: Iterable[str]
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str):
|
server_name:
|
||||||
key_ids (iterable[str]):
|
key_ids:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, FetchKeyResult]: map from key ID to lookup result
|
Map from key ID to lookup result
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyLookupError if there was a problem making the lookup
|
KeyLookupError if there was a problem making the lookup
|
||||||
"""
|
"""
|
||||||
keys = {} # type: dict[str, FetchKeyResult]
|
keys = {} # type: Dict[str, FetchKeyResult]
|
||||||
|
|
||||||
for requested_key_id in key_ids:
|
for requested_key_id in key_ids:
|
||||||
# we may have found this key as a side-effect of asking for another.
|
# we may have found this key as a side-effect of asking for another.
|
||||||
@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
raise KeyLookupError("Remote server returned an error: %s" % (e,))
|
raise KeyLookupError("Remote server returned an error: %s" % (e,))
|
||||||
|
|
||||||
|
assert isinstance(response, dict)
|
||||||
if response["server_name"] != server_name:
|
if response["server_name"] != server_name:
|
||||||
raise KeyLookupError(
|
raise KeyLookupError(
|
||||||
"Expected a response for server %r not %r"
|
"Expected a response for server %r not %r"
|
||||||
@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
async def _handle_key_deferred(verify_request) -> None:
|
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
|
||||||
"""Waits for the key to become available, and then performs a verification
|
"""Waits for the key to become available, and then performs a verification
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verify_request (VerifyJsonRequest):
|
verify_request:
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem performing the verification
|
SynapseError if there was a problem performing the verification
|
||||||
|
@ -144,7 +144,7 @@ class Authenticator:
|
|||||||
):
|
):
|
||||||
raise FederationDeniedError(origin)
|
raise FederationDeniedError(origin)
|
||||||
|
|
||||||
if not json_request["signatures"]:
|
if origin is None or not json_request["signatures"]:
|
||||||
raise NoAuthenticationError(
|
raise NoAuthenticationError(
|
||||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED
|
401, "Missing Authorization headers", Codes.UNAUTHORIZED
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Set
|
from typing import Dict
|
||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
|
|||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
cache_misses = {} # type: Dict[str, Set[str]]
|
# Note that the value is unused.
|
||||||
|
cache_misses = {} # type: Dict[str, Dict[str, int]]
|
||||||
for (server_name, key_id, from_server), results in cached.items():
|
for (server_name, key_id, from_server), results in cached.items():
|
||||||
results = [(result["ts_added_ms"], result) for result in results]
|
results = [(result["ts_added_ms"], result) for result in results]
|
||||||
|
|
||||||
if not results and key_id is not None:
|
if not results and key_id is not None:
|
||||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
cache_misses.setdefault(server_name, {})[key_id] = 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if key_id is not None:
|
if key_id is not None:
|
||||||
@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if miss:
|
if miss:
|
||||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
cache_misses.setdefault(server_name, {})[key_id] = 0
|
||||||
# Cast to bytes since postgresql returns a memoryview.
|
# Cast to bytes since postgresql returns a memoryview.
|
||||||
json_results.add(bytes(most_recent_result["key_json"]))
|
json_results.add(bytes(most_recent_result["key_json"]))
|
||||||
else:
|
else:
|
||||||
|
@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
|
|||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
async def get_server_verify_keys(
|
async def get_server_verify_keys(
|
||||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
|
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
|
||||||
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
|
) -> Dict[Tuple[str, str], FetchKeyResult]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
server_name_and_key_ids:
|
server_name_and_key_ids:
|
||||||
@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
keys = {}
|
keys = {}
|
||||||
|
|
||||||
def _get_keys(txn, batch):
|
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
|
||||||
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
||||||
|
|
||||||
# batch_iter always returns tuples so it's safe to do len(batch)
|
# batch_iter always returns tuples so it's safe to do len(batch)
|
||||||
@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
|
|||||||
# `ts_valid_until_ms`.
|
# `ts_valid_until_ms`.
|
||||||
ts_valid_until_ms = 0
|
ts_valid_until_ms = 0
|
||||||
|
|
||||||
res = FetchKeyResult(
|
keys[(server_name, key_id)] = FetchKeyResult(
|
||||||
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
|
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
|
||||||
valid_until_ts=ts_valid_until_ms,
|
valid_until_ts=ts_valid_until_ms,
|
||||||
)
|
)
|
||||||
keys[(server_name, key_id)] = res
|
|
||||||
|
|
||||||
def _txn(txn):
|
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
|
||||||
for batch in batch_iter(server_name_and_key_ids, 50):
|
for batch in batch_iter(server_name_and_key_ids, 50):
|
||||||
_get_keys(txn, batch)
|
_get_keys(txn, batch)
|
||||||
return keys
|
return keys
|
||||||
|
@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
||||||
mock_fetcher = keyring.KeyFetcher()
|
mock_fetcher = Mock()
|
||||||
mock_fetcher.get_keys = Mock()
|
mock_fetcher.get_keys = Mock()
|
||||||
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
"""Tests that we correctly handle key requests for keys we've stored
|
"""Tests that we correctly handle key requests for keys we've stored
|
||||||
with a null `ts_valid_until_ms`
|
with a null `ts_valid_until_ms`
|
||||||
"""
|
"""
|
||||||
mock_fetcher = keyring.KeyFetcher()
|
mock_fetcher = Mock()
|
||||||
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
|
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
|
||||||
|
|
||||||
kr = keyring.Keyring(
|
kr = keyring.Keyring(
|
||||||
@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_fetcher = keyring.KeyFetcher()
|
mock_fetcher = Mock()
|
||||||
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||||
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
||||||
|
|
||||||
@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_fetcher1 = keyring.KeyFetcher()
|
mock_fetcher1 = Mock()
|
||||||
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||||
mock_fetcher2 = keyring.KeyFetcher()
|
mock_fetcher2 = Mock()
|
||||||
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
|
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
|
||||||
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
|
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user