diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix new file mode 100644 index 000000000..492e58a7a --- /dev/null +++ b/changelog.d/10593.bugfix @@ -0,0 +1 @@ +Reject Client-Server /keys/query requests which provide device_ids incorrectly. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc662bca8..9480f448d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -147,6 +147,14 @@ class SynapseError(CodeMessageException): return cs_error(self.msg, self.errcode) +class InvalidAPICallError(SynapseError): + """You called an existing API endpoint, but fed that endpoint + invalid or incomplete data.""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) + + class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index d0d9d30d4..012491f59 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,8 +15,9 @@ # limitations under the License. import logging +from typing import Any -from synapse.api.errors import SynapseError +from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet): device_id = requester.device_id timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) + + device_keys = body.get("device_keys") + if not isinstance(device_keys, dict): + raise InvalidAPICallError("'device_keys' must be a JSON object") + + def is_list_of_strings(values: Any) -> bool: + return isinstance(values, list) and all(isinstance(v, str) for v in values) + + if any(not is_list_of_strings(keys) for keys in device_keys.values()): + raise InvalidAPICallError( + "'device_keys' values must be a list of strings", + ) + result = await self.e2e_keys_handler.query_devices( body, timeout, user_id, device_id ) diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py new file mode 100644 index 000000000..80a4e728f --- /dev/null +++ b/tests/rest/client/v2_alpha/test_keys.py @@ -0,0 +1,77 @@ +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + )