diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 0d24aa7ac..bfe6e6160 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -289,6 +289,7 @@ class Keyring(object): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_key.time_added = time_now_ms verify_keys[key_id] = verify_key old_verify_keys = {} diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 69bc15ba7..e434847b4 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,6 +13,7 @@ # limitations under the License. from synapse.http.server import request_handler, respond_with_json_bytes +from synapse.http.servlet import parse_integer from synapse.api.errors import SynapseError, Codes from twisted.web.resource import Resource @@ -44,7 +45,13 @@ class RemoteKey(Resource): POST /_matrix/v2/query HTTP/1.1 Content-Type: application/json { - "server_keys": { "remote.server.example.com": ["a.key.id"] } + "server_keys": { + "remote.server.example.com": { + "a.key.id": { + "minimum_valid_until_ts": 1234567890123 + } + } + } } Response: @@ -96,10 +103,16 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server: [None]} + query = {server: {}} elif len(request.postpath) == 2: server, key_id = request.postpath - query = {server: [key_id]} + minimum_valid_until_ts = parse_integer( + request, "minimum_valid_until_ts" + ) + arguments = {} + if minimum_valid_until_ts is not None: + arguments["minimum_valid_until_ts"] = minimum_valid_until_ts + query = {server: {key_id: arguments}} else: raise SynapseError( 404, "Not found %r" % request.postpath, Codes.NOT_FOUND @@ -128,8 +141,11 @@ class RemoteKey(Resource): @defer.inlineCallbacks def query_keys(self, request, query, query_remote_on_cache_miss=False): + logger.info("Handling query for keys %r", query) store_queries = [] for server_name, key_ids in query.items(): + if not key_ids: + key_ids = (None,) for key_id in key_ids: store_queries.append((server_name, key_id, None)) @@ -152,9 +168,44 @@ class RemoteKey(Resource): if key_id is not None: ts_added_ms, most_recent_result = max(results) ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] - if (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms: + req_key = query.get(server_name, {}).get(key_id, {}) + req_valid_until = req_key.get("minimum_valid_until_ts") + miss = False + if req_valid_until is not None: + if ts_valid_until_ms < req_valid_until: + logger.debug( + "Cached response for %r/%r is older than requested" + ": valid_until (%r) < minimum_valid_until (%r)", + server_name, key_id, + ts_valid_until_ms, req_valid_until + ) + miss = True + else: + logger.debug( + "Cached response for %r/%r is newer than requested" + ": valid_until (%r) >= minimum_valid_until (%r)", + server_name, key_id, + ts_valid_until_ms, req_valid_until + ) + elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms: + logger.debug( + "Cached response for %r/%r is too old" + ": (added (%r) + valid_until (%r)) / 2 < now (%r)", + server_name, key_id, + ts_added_ms, ts_valid_until_ms, time_now_ms + ) # We more than half way through the lifetime of the # response. We should fetch a fresh copy. + miss = True + else: + logger.debug( + "Cached response for %r/%r is still valid" + ": (added (%r) + valid_until (%r)) / 2 < now (%r)", + server_name, key_id, + ts_added_ms, ts_valid_until_ms, time_now_ms + ) + + if miss: cache_misses.setdefault(server_name, set()).add(key_id) json_results.add(bytes(most_recent_result["key_json"])) else: