diff --git a/changelog.d/15663.misc b/changelog.d/15663.misc new file mode 100644 index 000000000..cc5f80154 --- /dev/null +++ b/changelog.d/15663.misc @@ -0,0 +1 @@ +Add requesting user id parameter to key claim methods in `TransportLayerClient`. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 076b9287c..a2cf3a96c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -236,6 +236,7 @@ class FederationClient(FederationBase): async def claim_client_keys( self, + user: UserID, destination: str, query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], @@ -243,6 +244,7 @@ class FederationClient(FederationBase): """Claims one-time keys for a device hosted on a remote server. Args: + user: The user id of the requesting user destination: Domain name of the remote homeserver content: The query content. @@ -279,7 +281,7 @@ class FederationClient(FederationBase): if use_unstable: try: return await self.transport_layer.claim_client_keys_unstable( - destination, unstable_content, timeout + user, destination, unstable_content, timeout ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -295,7 +297,7 @@ class FederationClient(FederationBase): logger.debug("Skipping unstable claim client keys API") return await self.transport_layer.claim_client_keys( - destination, content, timeout + user, destination, content, timeout ) @trace diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1cfc4446c..0b17f713e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser from synapse.http.types import QueryParams -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import ExceptionBundle if TYPE_CHECKING: @@ -630,7 +630,11 @@ class TransportLayerClient: ) async def claim_client_keys( - self, destination: str, query_content: JsonDict, timeout: Optional[int] + self, + user: UserID, + destination: str, + query_content: JsonDict, + timeout: Optional[int], ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -655,6 +659,7 @@ class TransportLayerClient: } Args: + user: the user_id of the requesting user destination: The server to query. query_content: The user ids to query. Returns: @@ -671,7 +676,11 @@ class TransportLayerClient: ) async def claim_client_keys_unstable( - self, destination: str, query_content: JsonDict, timeout: Optional[int] + self, + user: UserID, + destination: str, + query_content: JsonDict, + timeout: Optional[int], ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -696,6 +705,7 @@ class TransportLayerClient: } Args: + user: the user_id of the requesting user destination: The server to query. query_content: The user ids to query. Returns: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 24741b667..ad075497c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -661,6 +661,7 @@ class E2eKeysHandler: async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, int]]], + user: UserID, timeout: Optional[int], always_include_fallback_keys: bool, ) -> JsonDict: @@ -703,7 +704,7 @@ class E2eKeysHandler: device_keys = remote_queries[destination] try: remote_result = await self.federation.claim_client_keys( - destination, device_keys, timeout=timeout + user, destination, device_keys, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 9bbab5e62..413edd8a4 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -287,7 +287,7 @@ class OneTimeKeyServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -298,7 +298,7 @@ class OneTimeKeyServlet(RestServlet): query.setdefault(user_id, {})[device_id] = {algorithm: 1} result = await self.e2e_keys_handler.claim_one_time_keys( - query, timeout, always_include_fallback_keys=False + query, requester.user, timeout, always_include_fallback_keys=False ) return 200, result @@ -335,7 +335,7 @@ class UnstableOneTimeKeyServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -346,7 +346,7 @@ class UnstableOneTimeKeyServlet(RestServlet): query.setdefault(user_id, {})[device_id] = Counter(algorithms) result = await self.e2e_keys_handler.claim_one_time_keys( - query, timeout, always_include_fallback_keys=True + query, requester.user, timeout, always_include_fallback_keys=True ) return 200, result diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 72d058406..2eaffe511 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -27,7 +27,7 @@ from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -45,6 +45,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main + self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}") def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" @@ -161,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res2 = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -206,6 +208,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -225,6 +228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -274,6 +278,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -286,6 +291,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -307,6 +313,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -348,6 +355,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -370,6 +378,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1080,6 +1089,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -1125,6 +1135,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1169,6 +1180,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1202,6 +1214,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1229,6 +1242,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, )