Limit the number of in-flight /keys/query requests from a single device. (#10144)

This commit is contained in:
Patrick Cloke 2021-06-09 07:05:32 -04:00 committed by GitHub
parent 1bf83a191b
commit 11846dff8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 198 additions and 175 deletions

1
changelog.d/10144.misc Normal file
View File

@ -0,0 +1 @@
Limit the number of in-flight `/keys/query` requests from a single device.

View File

@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys
)
# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
)
@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@ -105,8 +111,10 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
@ -143,12 +151,16 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
query_list.extend(
(user_id, device_id) for device_id in device_ids
)
else:
query_list.append((user_id, None))

View File

@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
return 200, result

View File

@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
{"device_keys": {local_user: [], other_user: []}},
0,
local_user,
"device123",
)
)