Return keys for unwhitelisted servers from /_matrix/key/v2/query (#13683)

This commit is contained in:
Richard van der Hoff 2022-09-01 13:54:02 +01:00 committed by GitHub
parent 18e4092801
commit e8130f219b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 20 deletions

1
changelog.d/13683.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug which meant that keys for unwhitelisted servers were not returned by `/_matrix/key/v2/query`.

View File

@ -135,13 +135,6 @@ class RemoteKey(DirectServeJsonResource):
store_queries = [] store_queries = []
for server_name, key_ids in query.items(): for server_name, key_ids in query.items():
if (
self.federation_domain_whitelist is not None
and server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue
if not key_ids: if not key_ids:
key_ids = (None,) key_ids = (None,)
for key_id in key_ids: for key_id in key_ids:
@ -153,21 +146,28 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
# Note that the value is unused. # Map server_name->key_id->int. Note that the value of the init is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {} cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items(): for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results] results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None: if key_id is None:
cache_misses.setdefault(server_name, {})[key_id] = 0 # all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
continue continue
if key_id is not None: miss = False
if not results:
miss = True
else:
ts_added_ms, most_recent_result = max(results) ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {}) req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts") req_valid_until = req_key.get("minimum_valid_until_ts")
miss = False
if req_valid_until is not None: if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until: if ts_valid_until_ms < req_valid_until:
logger.debug( logger.debug(
@ -211,19 +211,20 @@ class RemoteKey(DirectServeJsonResource):
ts_valid_until_ms, ts_valid_until_ms,
time_now_ms, time_now_ms,
) )
if miss:
cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview. # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"])) json_results.add(bytes(most_recent_result["key_json"]))
else:
for _, result in results: if miss and query_remote_on_cache_miss:
# Cast to bytes since postgresql returns a memoryview. # only bother attempting to fetch keys from servers on our whitelist
json_results.add(bytes(result["key_json"])) if (
self.federation_domain_whitelist is None
or server_name in self.federation_domain_whitelist
):
cache_misses.setdefault(server_name, {})[key_id] = 0
# If there is a cache miss, request the missing keys, then recurse (and # If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent). # ensure the result is sent).
if cache_misses and query_remote_on_cache_miss: if cache_misses:
await yieldable_gather_results( await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t), lambda t: self.fetcher.get_keys(*t),
( (