From 90abdaf3bcd3eceb6dd5688f9ef3623382883173 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 1 Jun 2015 10:51:50 +0100 Subject: [PATCH 01/59] Use Twisted-15.2.1, Use Agent.usingEndpointFactory rather than implement our own Agent --- synapse/http/matrixfederationclient.py | 75 +++++++++----------------- synapse/python_dependencies.py | 2 +- 2 files changed, 26 insertions(+), 51 deletions(-) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 7f3d8fc88..ec5b06ddc 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -16,7 +16,7 @@ from twisted.internet import defer, reactor, protocol from twisted.internet.error import DNSLookupError -from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool +from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone @@ -53,41 +53,17 @@ incoming_responses_counter = metrics.register_counter( ) -class MatrixFederationHttpAgent(_AgentBase): +class MatrixFederationEndpointFactory(object): + def __init__(self, hs): + self.tls_context_factory = hs.tls_context_factory - def __init__(self, reactor, pool=None): - _AgentBase.__init__(self, reactor, pool) + def endpointForURI(self, uri): + destination = uri.netloc - def request(self, destination, endpoint, method, path, params, query, - headers, body_producer): - - outgoing_requests_counter.inc(method) - - host = b"" - port = 0 - fragment = b"" - - parsed_URI = _URI(b"http", destination, host, port, path, params, - query, fragment) - - # Set the connection pool key to be the destination. - key = destination - - d = self._requestWithEndpoint(key, endpoint, method, parsed_URI, - headers, body_producer, - parsed_URI.originForm) - - def _cb(response): - incoming_responses_counter.inc(method, response.code) - return response - - def _eb(failure): - incoming_responses_counter.inc(method, "ERR") - return failure - - d.addCallbacks(_cb, _eb) - - return d + return matrix_federation_endpoint( + reactor, destination, timeout=10, + ssl_context_factory=self.tls_context_factory + ) class MatrixFederationHttpClient(object): @@ -105,10 +81,17 @@ class MatrixFederationHttpClient(object): self.server_name = hs.hostname pool = HTTPConnectionPool(reactor) pool.maxPersistentPerHost = 10 - self.agent = MatrixFederationHttpAgent(reactor, pool=pool) + self.agent = Agent.usingEndpointFactory( + reactor, MatrixFederationEndpointFactory(hs), pool=pool + ) self.clock = hs.get_clock() self.version_string = hs.version_string + def _create_url(self, destination, path_bytes, param_bytes, query_bytes): + return urlparse.urlunparse( + ("matrix", destination, path_bytes, param_bytes, query_bytes, "") + ) + @defer.inlineCallbacks def _create_request(self, destination, method, path_bytes, body_callback, headers_dict={}, param_bytes=b"", @@ -119,8 +102,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"Host"] = [destination] - url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "",) + url_bytes = self._create_url( + destination, path_bytes, param_bytes, query_bytes ) logger.info("Sending request to %s: %s %s", @@ -139,22 +122,20 @@ class MatrixFederationHttpClient(object): # (once we have reliable transactions in place) retries_left = 5 - endpoint = self._getEndpoint(reactor, destination) + http_url_bytes = urlparse.urlunparse( + ("", "", path_bytes, param_bytes, query_bytes, "") + ) while True: producer = None if body_callback: - producer = body_callback(method, url_bytes, headers_dict) + producer = body_callback(method, http_url_bytes, headers_dict) try: request_deferred = preserve_context_over_fn( self.agent.request, - destination, - endpoint, method, - path_bytes, - param_bytes, - query_bytes, + url_bytes, Headers(headers_dict), producer ) @@ -442,12 +423,6 @@ class MatrixFederationHttpClient(object): defer.returnValue((length, headers)) - def _getEndpoint(self, reactor, destination): - return matrix_federation_endpoint( - reactor, destination, timeout=10, - ssl_context_factory=self.hs.tls_context_factory - ) - class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index a45dd3c93..d1bf6d50c 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil>=0.0.6": ["syutil>=0.0.6"], - "Twisted==14.0.2": ["twisted==14.0.2"], + "Twisted==15.2.1": ["twisted==15.2.1"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], From 66da8f60d057527fe74ff5790973d2865c980efb Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 15 Jun 2015 16:27:20 +0100 Subject: [PATCH 02/59] Bump the version of twisted needed for setup_requires to 15.2.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f9929591e..c790484c7 100755 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ setup( description="Reference Synapse Home Server", install_requires=dependencies['requirements'](include_conditional=True).keys(), setup_requires=[ - "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 + "Twisted==15.2.1", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "setuptools_trial", "mock" ], From 62c010283d543db0956066b42eb735b57c000a82 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 23 Jul 2015 16:03:38 +0100 Subject: [PATCH 03/59] Add federation support for end-to-end key requests --- synapse/federation/federation_client.py | 34 ++++++++ synapse/federation/federation_server.py | 37 +++++++++ synapse/federation/transport/client.py | 70 +++++++++++++++++ synapse/federation/transport/server.py | 20 +++++ synapse/rest/client/v2_alpha/keys.py | 100 +++++++++++++++++------- 5 files changed, 231 insertions(+), 30 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb..21a86a4c6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -134,6 +134,40 @@ class FederationClient(FederationBase): destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail ) + @log_function + def query_client_keys(self, destination, content, retry_on_dns_fail=True): + """Query device keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_device_keys") + return self.transport_layer.query_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + + @log_function + def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + """Claims one-time keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_one_time_keys") + return self.transport_layer.claim_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + @defer.inlineCallbacks @log_function def backfill(self, dest, context, limit, extremities): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cd79e23f4..c32908ac2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature +import simplejson as json import logging @@ -312,6 +313,42 @@ class FederationServer(FederationBase): (200, send_content) ) + @defer.inlineCallbacks + @log_function + def on_query_client_keys(self, origin, content): + query = [] + for user_id, device_ids in content.get("device_keys", {}).items(): + if not device_ids: + query.append((user_id, None)) + else: + for device_id in device_ids: + query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, json_bytes in device_keys.items(): + json_result.setdefault(user_id, {})[device_id] = json.loads( + json_bytes + ) + defer.returnValue({"device_keys": json_result}) + + @defer.inlineCallbacks + @log_function + def on_claim_client_keys(self, origin, content): + query = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + defer.returnValue({"one_time_keys": json_result}) + @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 610a4c316..df5083dd2 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -222,6 +222,76 @@ class TransportLayerClient(object): defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_client_keys(self, destination, query_content): + """Query the device keys for a list of user ids hosted on a remote + server. + + Request: + { + "device_keys": { + "": [""] + } } + + Response: + { + "device_keys": { + "": { + "": {...} + } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. + """ + path = PREFIX + "/client_keys/query" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def claim_client_keys(self, destination, query_content): + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "": { + "": "" + } } } + + Response: + { + "device_keys": { + "": { + "": { + ":": "" + } } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the one-time keys. + """ + + path = PREFIX + "/client_keys/claim" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + @defer.inlineCallbacks @log_function def get_missing_events(self, destination, room_id, earliest_events, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bad93c6b2..fb59383ec 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet): defer.returnValue((200, content)) +class FederationClientKeysQueryServlet(BaseFederationServlet): + PATH = "/client_keys/query" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_query(origin, content) + defer.returnValue((200, response)) + + +class FederationClientKeysClaimServlet(BaseFederationServlet): + PATH = "/client_keys/claim" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_claim(origin, content) + defer.returnValue((200, response)) + + class FederationQueryAuthServlet(BaseFederationServlet): PATH = "/query_auth/([^/]*)/([^/]*)" @@ -373,4 +391,6 @@ SERVLET_CLASSES = ( FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationClientKeysClaimServlet, ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5f3a6207b..739a08ada 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet +from synapse.types import UserID from syutil.jsonutil import encode_canonical_json from ._base import client_v2_pattern @@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet): super(KeyQueryServlet, self).__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - logger.debug("onPOST") yield self.auth.get_user_by_req(request) try: body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_ids in body.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - results = yield self.store.get_e2e_device_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() - if not user_id: - user_id = auth_user_id - if not device_id: - device_id = None - # Returns a map of user_id->device_id->json_bytes. - results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) - defer.returnValue(self.json_result(request, results)) + user_id = user_id if user_id else auth_user_id + device_ids = [device_id] if device_id else [] + result = yield self.handle_request( + {"device_keys": {user_id: device_ids}} + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_ids in body.get("device_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + else: + remote_queries.set_default(user.domain, {})[user_id] = list( + device_ids + ) + results = yield self.store.get_e2e_device_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) - return (200, {"device_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"device_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + defer.returnValue((200, {"device_keys": json_result})) class OneTimeKeyServlet(RestServlet): @@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - results = yield self.store.claim_e2e_one_time_keys( - [(user_id, device_id, algorithm)] + result = yield self.handle_request( + {"one_time_keys": {user_id: {device_id: algorithm}}} ) - defer.returnValue(self.json_result(request, results)) + defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): @@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet): body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_keys in body.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) - results = yield self.store.claim_e2e_one_time_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_keys in body.get("one_time_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + remote_queries.set_default(user.domain, {})[user_id] = ( + device_keys + ) + results = yield self.store.claim_e2e_one_time_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } - return (200, {"one_time_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"one_time_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + + defer.returnValue((200, {"one_time_keys": json_result})) def register_servlets(hs, http_server): From 2da3b1e60bf7e9ae1d6714abcff0a0c224cadf28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 24 Jul 2015 18:26:46 +0100 Subject: [PATCH 04/59] Get the end-to-end key federation working --- synapse/federation/federation_client.py | 12 ++++-------- synapse/federation/transport/client.py | 4 ++-- synapse/federation/transport/server.py | 12 ++++++------ synapse/rest/client/v2_alpha/keys.py | 10 +++++----- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 21a86a4c6..44e4d0755 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content, retry_on_dns_fail=True): + def query_client_keys(self, destination, content): """Query device keys for a device hosted on a remote server. Args: @@ -147,12 +147,10 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.query_client_keys(destination, content) @log_function - def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + def claim_client_keys(self, destination, content): """Claims one-time keys for a device hosted on a remote server. Args: @@ -164,9 +162,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.claim_client_keys(destination, content) @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index df5083dd2..ced703364 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -247,7 +247,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/client_keys/query" + path = PREFIX + "/user/keys/query" content = yield self.client.post_json( destination=destination, @@ -283,7 +283,7 @@ class TransportLayerClient(object): A dict containg the one-time keys. """ - path = PREFIX + "/client_keys/claim" + path = PREFIX + "/user/keys/claim" content = yield self.client.post_json( destination=destination, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fb59383ec..36f250e1a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -326,20 +326,20 @@ class FederationInviteServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): - PATH = "/client_keys/query" + PATH = "/user/keys/query" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_query(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_query_client_keys(origin, content) defer.returnValue((200, response)) class FederationClientKeysClaimServlet(BaseFederationServlet): - PATH = "/client_keys/claim" + PATH = "/user/keys/claim" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_claim(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_claim_client_keys(origin, content) defer.returnValue((200, response)) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 739a08ada..718928eed 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -202,7 +202,7 @@ class KeyQueryServlet(RestServlet): for device_id in device_ids: local_query.append((user_id, device_id)) else: - remote_queries.set_default(user.domain, {})[user_id] = list( + remote_queries.setdefault(user.domain, {})[user_id] = list( device_ids ) results = yield self.store.get_e2e_device_keys(local_query) @@ -218,7 +218,7 @@ class KeyQueryServlet(RestServlet): remote_result = yield self.federation.query_client_keys( destination, {"device_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["device_keys"].items(): if user_id in device_keys: json_result[user_id] = keys defer.returnValue((200, {"device_keys": json_result})) @@ -286,7 +286,7 @@ class OneTimeKeyServlet(RestServlet): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: - remote_queries.set_default(user.domain, {})[user_id] = ( + remote_queries.setdefault(user.domain, {})[user_id] = ( device_keys ) results = yield self.store.claim_e2e_one_time_keys(local_query) @@ -300,10 +300,10 @@ class OneTimeKeyServlet(RestServlet): } for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( + remote_result = yield self.federation.claim_client_keys( destination, {"one_time_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys From 4d6cb8814e134eba644afeed7bd49df0c7951342 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 Aug 2015 09:32:23 +0100 Subject: [PATCH 05/59] Speed up event filtering (for ACL) logic --- synapse/handlers/federation.py | 6 +- synapse/handlers/message.py | 6 +- synapse/handlers/sync.py | 6 +- synapse/storage/_base.py | 10 ++- synapse/storage/state.py | 123 ++++++++++++++++++++++----------- 5 files changed, 105 insertions(+), 46 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f7155fd8d..22f534e49 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -230,7 +230,11 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) ) events_and_states = zip(events, states) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d6d4f097..765b14d99 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -138,7 +138,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) events_and_states = zip(events, states) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6cff6230c..8f58774b3 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -295,7 +295,11 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) events_and_states = zip(events, states) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd..7b76ee3b7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -71,6 +71,11 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -85,9 +90,10 @@ class Cache(object): if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) - if keyargs in self.cache: + val = self.cache.get(keyargs, self.sentinel) + if val is not self.sentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) raise KeyError() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 47bec6549..7e9bd232c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached from twisted.internet import defer +from synapse.util import unwrapFirstError from synapse.util.stringutils import random_string import logging @@ -206,62 +207,102 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + @cached(num_args=3, lru=True) + def _get_state_groups_from_group(self, room_id, group, types): def f(txn): - groups = set() - event_to_group = {} - for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - event_to_group[event_id] = group - groups.add(group) + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " room_id = ? AND state_group = ? AND (%s)" + ) % (" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),) - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) + args = [room_id, group] + args.extend([i for typ in types for i in typ]) + txn.execute(sql, args) - group_to_state_ids[group] = state_ids + return group, [ + r[0] + for r in txn.fetchall() + ] - return event_to_group, group_to_state_ids - - res = yield self.runInteraction( - "annotate_events_with_state_groups", + return self.runInteraction( + "_get_state_groups_from_group", f, ) - event_to_group, group_to_state_ids = res + @cached(num_args=3, lru=True, max_entries=100000) + def _get_state_for_event_id(self, room_id, event_id, types): + def f(txn): + type_and_state_sql = " OR ".join([ + "(type = ? AND state_key = ?)" + if typ[1] is not None + else "type = ?" + for typ in types + ]) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() - ], - consumeErrors=True, + sql = ( + "SELECT sg.event_id FROM state_groups_state as sg" + " INNER JOIN event_to_state_groups as e" + " ON e.state_group = sg.state_group" + " WHERE e.event_id = ? AND (%s)" + ) % (type_and_state_sql,) + + args = [event_id] + for typ, state_key in types: + args.extend( + [typ, state_key] if state_key is not None else [typ] + ) + txn.execute(sql, args) + + return event_id, [ + r[0] + for r in txn.fetchall() + ] + + return self.runInteraction( + "_get_state_for_event_id", + f, ) - state_dict = { - group: { + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids, types): + set_types = frozenset(types) + res = yield defer.gatherResults( + [ + self._get_state_for_event_id( + room_id, event_id, set_types, + ) + for event_id in event_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + event_to_state_ids = dict(res) + + event_dict = yield self._get_events( + [ + item + for lst in event_to_state_ids.values() + for item in lst + ], + get_prev_content=False + ).addCallback( + lambda evs: {ev.event_id: ev for ev in evs} + ) + + event_to_state = { + event_id: { (ev.type, ev.state_key): ev - for ev in state + for ev in [ + event_dict[state_id] + for state_id in state_ids + if state_id in event_dict + ] } - for group, state in state_list + for event_id, state_ids in event_to_state_ids.items() } defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) + event_to_state[event] for event in event_ids ]) From 413a4c289b0a7bc9655a6f8543ccf1375e2a9e34 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 Aug 2015 11:08:07 +0100 Subject: [PATCH 06/59] Add comment --- synapse/storage/state.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7e9bd232c..91a5ae86a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -265,6 +265,21 @@ class StateStore(SQLBaseStore): @defer.inlineCallbacks def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ set_types = frozenset(types) res = yield defer.gatherResults( [ From e7768e77f56709897215c13ee3fd187550d20fd2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 Aug 2015 15:56:56 +0100 Subject: [PATCH 07/59] Add basic dictionary cache --- synapse/storage/util/caches.py | 94 ++++++++++++++++++++++++++++++ tests/util/test_dict_cache.py | 101 +++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 synapse/storage/util/caches.py create mode 100644 tests/util/test_dict_cache.py diff --git a/synapse/storage/util/caches.py b/synapse/storage/util/caches.py new file mode 100644 index 000000000..0877cc79f --- /dev/null +++ b/synapse/storage/util/caches.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.lrucache import LruCache +from collections import namedtuple +import threading + + +DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) + + +class DictionaryCache(object): + + def __init__(self, name, max_entries=1000): + self.cache = LruCache(max_size=max_entries) + + self.name = name + self.sequence = 0 + self.thread = None + # caches_by_name[name] = self.cache + + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, dict_keys=None): + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + # cache_counter.inc_hits(self.name) + + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) + + # cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + + def invalidate(self, key): + self.check_thread() + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + def update(self, sequence, key, value, full=False): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + + def _update_or_insert(self, key, value): + entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry.value.update(value) + + def _insert(self, key, value): + self.cache[key] = DictionaryEntry(True, value) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py new file mode 100644 index 000000000..8cb9be658 --- /dev/null +++ b/tests/util/test_dict_cache.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer +from tests import unittest + +from synapse.storage.util.caches import DictionaryCache + + +class DictCacheTestCase(unittest.TestCase): + + def setUp(self): + self.cache = DictionaryCache("foobar") + + def test_simple_cache_hit_full(self): + key = "test_simple_cache_hit_full" + + v = self.cache.get(key) + self.assertEqual((False, {}), v) + + seq = self.cache.sequence + test_value = {"test": "test_simple_cache_hit_full"} + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key) + self.assertEqual(test_value, c.value) + + def test_simple_cache_hit_partial(self): + key = "test_simple_cache_hit_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test"]) + self.assertEqual(test_value, c.value) + + def test_simple_cache_miss_partial(self): + key = "test_simple_cache_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_miss_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({}, c.value) + + def test_simple_cache_hit_miss_partial(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + "test3": "test_simple_cache_hit_miss_partial3", + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) + + def test_multi_insert(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value_1 = { + "test": "test_simple_cache_hit_miss_partial", + } + self.cache.update(seq, key, test_value_1, full=False) + + seq = self.cache.sequence + test_value_2 = { + "test2": "test_simple_cache_hit_miss_partial2", + } + self.cache.update(seq, key, test_value_2, full=False) + + c = self.cache.get(key) + self.assertEqual( + { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + }, + c.value + ) From c67ba143fa1c5d9de0e3e998b9ec74c1f3342742 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 Aug 2015 15:58:28 +0100 Subject: [PATCH 08/59] Move DictionaryCache --- synapse/{storage/util/caches.py => util/dictionary_cache.py} | 0 tests/util/test_dict_cache.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename synapse/{storage/util/caches.py => util/dictionary_cache.py} (100%) diff --git a/synapse/storage/util/caches.py b/synapse/util/dictionary_cache.py similarity index 100% rename from synapse/storage/util/caches.py rename to synapse/util/dictionary_cache.py diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 8cb9be658..79bc1225d 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -17,7 +17,7 @@ from twisted.internet import defer from tests import unittest -from synapse.storage.util.caches import DictionaryCache +from synapse.util.dictionary_cache import DictionaryCache class DictCacheTestCase(unittest.TestCase): From 07507643cb6a2fde1a87d229f8d77525627a0632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2015 15:06:51 +0100 Subject: [PATCH 09/59] Use dictionary cache to do group -> state fetching --- synapse/handlers/federation.py | 2 +- synapse/state.py | 10 +- synapse/storage/_base.py | 39 ++++--- synapse/storage/state.py | 195 ++++++++++++++++++++----------- synapse/storage/stream.py | 3 +- synapse/util/dictionary_cache.py | 54 +++++---- tests/test_state.py | 2 +- 7 files changed, 195 insertions(+), 110 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 22f534e49..90649af9e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -507,7 +507,7 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups([e]) + self.state_handler.resolve_state_groups(room_id, [e]) for e in event_ids ]) states = dict(zip(event_ids, [s[1] for s in states])) diff --git a/synapse/state.py b/synapse/state.py index 80da90a72..b5e5d7bbd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -96,7 +96,7 @@ class StateHandler(object): cache.ts = self.clock.time_msec() state = cache.state else: - res = yield self.resolve_state_groups(event_ids) + res = yield self.resolve_state_groups(room_id, event_ids) state = res[1] if event_type: @@ -155,13 +155,13 @@ class StateHandler(object): if event.is_state(): ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], ) group, curr_state, prev_state = ret @@ -180,7 +180,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def resolve_state_groups(self, event_ids, event_type=None, state_key=""): + def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -205,7 +205,7 @@ class StateHandler(object): ) state_groups = yield self.store.get_state_groups( - event_ids + room_id, event_ids ) logger.debug( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7b76ee3b7..803b9d599 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,6 +18,7 @@ from synapse.api.errors import StoreError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache +from synapse.util.dictionary_cache import DictionaryCache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator @@ -87,23 +88,33 @@ class Cache(object): ) def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + try: + if len(keyargs) != self.keylen: + raise ValueError("Expected a key to have %d items", self.keylen) - val = self.cache.get(keyargs, self.sentinel) - if val is not self.sentinel: - cache_counter.inc_hits(self.name) - return val + val = self.cache.get(keyargs, self.sentinel) + if val is not self.sentinel: + cache_counter.inc_hits(self.name) + return val - cache_counter.inc_misses(self.name) - raise KeyError() + cache_counter.inc_misses(self.name) + raise KeyError() + except KeyError: + raise + except: + logger.exception("Cache.get failed for %s" % (self.name,)) + raise def update(self, sequence, *args): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(*args) + except: + logger.exception("Cache.update failed for %s" % (self.name,)) + raise def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] @@ -327,6 +338,8 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 91a5ae86a..a967b3d44 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -45,52 +45,38 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, - ) - - state_list = yield defer.gatherResults( + event_and_groups = yield defer.gatherResults( [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) + for event_id in event_ids ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) - defer.returnValue(dict(state_list)) + groups = set(group for _, group in event_and_groups if group) + + group_to_state = yield defer.gatherResults( + [ + self._get_state_for_group( + group, + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state + }) @cached(num_args=1) def _fetch_events_for_group(self, key, events): @@ -207,16 +193,25 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3, lru=True) - def _get_state_groups_from_group(self, room_id, group, types): + @cached(num_args=2, lru=True, max_entries=10000) + def _get_state_groups_from_group(self, group, types): def f(txn): + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + sql = ( "SELECT event_id FROM state_groups_state WHERE" - " room_id = ? AND state_group = ? AND (%s)" - ) % (" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),) + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) - args = [room_id, group] - args.extend([i for typ in types for i in typ]) txn.execute(sql, args) return group, [ @@ -229,7 +224,7 @@ class StateStore(SQLBaseStore): f, ) - @cached(num_args=3, lru=True, max_entries=100000) + @cached(num_args=3, lru=True, max_entries=20000) def _get_state_for_event_id(self, room_id, event_id, types): def f(txn): type_and_state_sql = " OR ".join([ @@ -280,40 +275,33 @@ class StateStore(SQLBaseStore): deferred: A list of dicts corresponding to the event_ids given. The dicts are mappings from (type, state_key) -> state_events """ - set_types = frozenset(types) - res = yield defer.gatherResults( + event_and_groups = yield defer.gatherResults( [ - self._get_state_for_event_id( - room_id, event_id, set_types, - ) + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) for event_id in event_ids ], consumeErrors=True, ).addErrback(unwrapFirstError) - event_to_state_ids = dict(res) + groups = set(group for _, group in event_and_groups) - event_dict = yield self._get_events( + res = yield defer.gatherResults( [ - item - for lst in event_to_state_ids.values() - for item in lst + self._get_state_for_group( + group, types + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups ], - get_prev_content=False - ).addCallback( - lambda evs: {ev.event_id: ev for ev in evs} - ) + consumeErrors=True, + ).addErrback(unwrapFirstError) + + group_to_state = dict(res) event_to_state = { - event_id: { - (ev.type, ev.state_key): ev - for ev in [ - event_dict[state_id] - for state_id in state_ids - if state_id in event_dict - ] - } - for event_id, state_ids in event_to_state_ids.items() + event_id: group_to_state[group] + for event_id, group in event_and_groups } defer.returnValue([ @@ -321,6 +309,79 @@ class StateStore(SQLBaseStore): for event in event_ids ]) + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @defer.inlineCallbacks + def _get_state_for_group(self, group, types=None): + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + if types is not None: + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + if is_all and types is None: + defer.returnValue(state_dict) + + if is_all or (types is not None and not missing_types): + def include(typ, state_key): + sentinel = object() + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + defer.returnValue({ + k: v + for k, v in state_dict.items() + if include(k[0], k[1]) + }) + + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence + _, state_ids = yield self._get_state_groups_from_group( + group, + frozenset(types) if types else None + ) + state_events = yield self._get_events(state_ids, get_prev_content=False) + state_dict = { + (e.type, e.state_key): e + for e in state_events + } + + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + defer.returnValue(state_dict) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index af45fc561..9db259d5f 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -300,8 +300,7 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False, from_token=None): + def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/dictionary_cache.py index 0877cc79f..38b131677 100644 --- a/synapse/util/dictionary_cache.py +++ b/synapse/util/dictionary_cache.py @@ -16,6 +16,10 @@ from synapse.util.lrucache import LruCache from collections import namedtuple import threading +import logging + + +logger = logging.getLogger(__name__) DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) @@ -47,21 +51,25 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - # cache_counter.inc_hits(self.name) + try: + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + # cache_counter.inc_hits(self.name) - if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) - else: - return DictionaryEntry(entry.full, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) - # cache_counter.inc_misses(self.name) - return DictionaryEntry(False, {}) + # cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + except: + logger.exception("get failed") + raise def invalidate(self, key): self.check_thread() @@ -77,14 +85,18 @@ class DictionaryCache(object): self.cache.clear() def update(self, sequence, key, value, full=False): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - if full: - self._insert(key, value) - else: - self._update_or_insert(key, value) + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + except: + logger.exception("update failed") + raise def _update_or_insert(self, key, value): entry = self.cache.setdefault(key, DictionaryEntry(False, {})) diff --git a/tests/test_state.py b/tests/test_state.py index fea25f702..584535875 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -69,7 +69,7 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) From fe994e728fba5ef43e0436f6472cd94d6ce3c902 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 10:17:38 +0100 Subject: [PATCH 10/59] Store absence of state in cache --- synapse/storage/state.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 48a402355..e924258d1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -234,7 +234,8 @@ class StateStore(SQLBaseStore): ]) sql = ( - "SELECT sg.event_id FROM state_groups_state as sg" + "SELECT e.event_id, sg.state_group, sg.event_id" + " FROM state_groups_state as sg" " INNER JOIN event_to_state_groups as e" " ON e.state_group = sg.state_group" " WHERE e.event_id = ? AND (%s)" @@ -342,8 +343,9 @@ class StateStore(SQLBaseStore): defer.returnValue(state_dict) if is_all or (types is not None and not missing_types): + sentinel = object() + def include(typ, state_key): - sentinel = object() valid_state_keys = type_to_key.get(typ, sentinel) if valid_state_keys is sentinel: return False @@ -356,20 +358,24 @@ class StateStore(SQLBaseStore): defer.returnValue({ k: v for k, v in state_dict.items() - if include(k[0], k[1]) + if v and include(k[0], k[1]) }) # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence _, state_ids = yield self._get_state_groups_from_group( group, - frozenset(types) if types else None + frozenset(missing_types) if types else None ) state_events = yield self._get_events(state_ids, get_prev_content=False) state_dict = { + key: None + for key in missing_types + } + state_dict.update({ (e.type, e.state_key): e for e in state_events - } + }) # Update the cache self._state_group_cache.update( @@ -379,7 +385,11 @@ class StateStore(SQLBaseStore): full=(types is None), ) - defer.returnValue(state_dict) + defer.returnValue({ + key: value + for key, value in state_dict.items() + if value + }) def _make_group_id(clock): From b8e386db59eea6a59b8338acfd8ea42632d539be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 11:52:21 +0100 Subject: [PATCH 11/59] Change Cache to not use *args in its interface --- synapse/storage/__init__.py | 4 +- synapse/storage/_base.py | 85 ++++++++++++----------------- synapse/storage/directory.py | 4 +- synapse/storage/event_federation.py | 4 +- synapse/storage/events.py | 21 +++---- synapse/storage/keys.py | 2 +- synapse/storage/presence.py | 4 +- synapse/storage/push_rule.py | 16 +++--- synapse/storage/registration.py | 2 +- synapse/storage/roommember.py | 6 +- tests/storage/test__base.py | 12 ++-- 11 files changed, 73 insertions(+), 87 deletions(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 71d5d9250..1a6a8a376 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d4751769e..32089b05e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -58,6 +58,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -74,11 +77,6 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache - class Sentinel(object): - __slots__ = [] - - self.sentinel = Sentinel() - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -89,52 +87,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - try: - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val - val = self.cache.get(keyargs, self.sentinel) - if val is not self.sentinel: - cache_counter.inc_hits(self.name) - return val + cache_counter.inc_misses(self.name) - cache_counter.inc_misses(self.name) + if default is _CacheSentinel: raise KeyError() - except KeyError: - raise - except: - logger.exception("Cache.get failed for %s" % (self.name,)) - raise + else: + return default - def update(self, sequence, *args): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - except: - logger.exception("Cache.update failed for %s" % (self.name,)) - raise - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def update(self, sequence, keyargs, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -185,20 +169,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -219,7 +204,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -227,19 +212,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf861..f3947bbe8 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 45b86c94e..910b6598a 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore): for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed7ea3880..5b6491802 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): if not context.rejected: txn.call_after( self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + (event.room_id, event.type, event.state_key,) + ) if event.type in [EventTypes.Name, EventTypes.Aliases]: txn.call_after( self.get_room_name_and_aliases.invalidate, - event.room_id + (event.room_id,) ) self._simple_upsert_txn( @@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3f98f0cd..49b8e37cf 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate(server_name) + self.get_all_server_verify_keys.invalidate((server_name,)) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce..576cf670c 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a220f3632..9b88ca7b3 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be..4eaa088b3 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 55dd3f6cf..9f14f38f2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8fa305d18..abee2f631 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase): self.assertEquals(self.cache.get("foo"), 123) def test_invalidate(self): - self.cache.prefill("foo", 123) - self.cache.invalidate("foo") + self.cache.prefill(("foo",), 123) + self.cache.invalidate(("foo",)) failed = False try: - self.cache.get("foo") + self.cache.get(("foo",)) except KeyError: failed = True @@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 1) - a.func.invalidate("foo") + a.func.invalidate(("foo",)) yield a.func("foo") @@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): def func(self, key): return key - A().func.invalidate("what") + A().func.invalidate(("what",)) @defer.inlineCallbacks def test_max_entries(self): @@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a = A() - a.func.prefill("foo", ObservableDeferred(d)) + a.func.prefill(("foo",), ObservableDeferred(d)) self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) From b3768ec10ad7a96f0d7f9d774c44fe8ade1f80e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 13:41:05 +0100 Subject: [PATCH 12/59] Remove unncessary cache --- synapse/storage/state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e924258d1..5588c9e69 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -192,7 +192,6 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=2, lru=True, max_entries=10000) def _get_state_groups_from_group(self, group, types): def f(txn): if types is not None: From b2c7bd4b098bea86189b2c8323265e23673d257d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 14:42:34 +0100 Subject: [PATCH 13/59] Cache get_recent_events_for_room --- synapse/storage/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 9db259d5f..b59fe8100 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,7 +35,7 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore +from ._base import SQLBaseStore, cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -299,7 +299,7 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=4) def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback From ffdb8c382860ce2e351614a91c2ce07a91c61455 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 18:13:48 +0100 Subject: [PATCH 14/59] Don't be too enthusiatic with defer.gatherResults --- synapse/handlers/message.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 765b14d99..11c736f72 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -405,10 +405,14 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") - yield defer.gatherResults( - [handle_room(e) for e in room_list], - consumeErrors=True - ).addErrback(unwrapFirstError) + # Only do N rooms at once + n = 5 + d_list = [handle_room(e) for e in room_list] + for ds in [d_list[i:i + n] for i in range(0, len(d_list), n)]: + yield defer.gatherResults( + ds, + consumeErrors=True + ).addErrback(unwrapFirstError) ret = { "rooms": rooms_ret, From 02118901347ab5067fbb47169b36a8424e671bfb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 18:14:49 +0100 Subject: [PATCH 15/59] Implement a CacheListDescriptor --- synapse/storage/_base.py | 106 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 32089b05e..556aa3b52 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,6 +16,7 @@ import logging from synapse.api.errors import StoreError from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -231,6 +232,101 @@ class CacheDescriptor(object): return wrapped +class CacheListDescriptor(object): + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + def cached(max_entries=1000, num_args=1, lru=True): return lambda orig: CacheDescriptor( orig, @@ -250,6 +346,16 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): ) +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() From 9eb5b23d3aeb0ddfb91cba0f36155dd8dcfbe8a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 18:15:30 +0100 Subject: [PATCH 16/59] Batch up various DB requests for event -> state --- synapse/storage/state.py | 219 +++++++++++++++++++++++++-------------- 1 file changed, 142 insertions(+), 77 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5588c9e69..a04731ae1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached, cachedInlineCallbacks +from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList from twisted.internet import defer -from synapse.util import unwrapFirstError from synapse.util.stringutils import random_string import logging @@ -50,32 +49,20 @@ class StateStore(SQLBaseStore): The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - event_and_groups = yield defer.gatherResults( - [ - self._get_state_group_for_event( - room_id, event_id, - ).addCallback(lambda group, event_id: (event_id, group), event_id) - for event_id in event_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) - groups = set(group for _, group in event_and_groups if group) + groups = set(event_to_groups.values()) - group_to_state = yield defer.gatherResults( - [ - self._get_state_for_group( - group, - ).addCallback(lambda state_dict, group: (group, state_dict), group) - for group in groups - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + group_to_state = yield self._get_state_for_groups(groups) defer.returnValue({ group: state_map.values() - for group, state_map in group_to_state + for group, state_map in group_to_state.items() }) @cached(num_args=1) @@ -212,17 +199,48 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) - return group, [ - r[0] - for r in txn.fetchall() - ] + return [r[0] for r in txn.fetchall()] return self.runInteraction( "_get_state_groups_from_group", f, ) - @cached(num_args=3, lru=True, max_entries=20000) + def _get_state_groups_from_groups(self, groups_and_types): + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [ + r[0] + for r in txn.fetchall() + ] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + + @cached(num_args=3, lru=True, max_entries=10000) def _get_state_for_event_id(self, room_id, event_id, types): def f(txn): type_and_state_sql = " OR ".join([ @@ -274,33 +292,19 @@ class StateStore(SQLBaseStore): deferred: A list of dicts corresponding to the event_ids given. The dicts are mappings from (type, state_key) -> state_events """ - event_and_groups = yield defer.gatherResults( - [ - self._get_state_group_for_event( - room_id, event_id, - ).addCallback(lambda group, event_id: (event_id, group), event_id) - for event_id in event_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) - groups = set(group for _, group in event_and_groups) + groups = set(event_to_groups.values()) - res = yield defer.gatherResults( - [ - self._get_state_for_group( - group, types - ).addCallback(lambda state_dict, group: (group, state_dict), group) - for group in groups - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - group_to_state = dict(res) + group_to_state = yield self._get_state_for_groups( + groups, types + ) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_and_groups + for event_id, group in event_to_groups.items() } defer.returnValue([ @@ -320,8 +324,29 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @defer.inlineCallbacks - def _get_state_for_group(self, group, types=None): + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + def f(txn): + results = {} + for event_id in event_ids: + results[event_id] = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + ) + + return results + + return self.runInteraction( + "_get_state_group_for_events", + f, + ) + + def _get_state_for_group_from_cache(self, group, types=None): is_all, state_dict = self._state_group_cache.get(group) type_to_key = {} @@ -339,7 +364,7 @@ class StateStore(SQLBaseStore): missing_types.add((typ, state_key)) if is_all and types is None: - defer.returnValue(state_dict) + return state_dict, missing_types if is_all or (types is not None and not missing_types): sentinel = object() @@ -354,41 +379,81 @@ class StateStore(SQLBaseStore): return True return False - defer.returnValue({ + return { k: v for k, v in state_dict.items() if v and include(k[0], k[1]) - }) + }, missing_types + + return {}, missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + results = {} + missing_groups_and_types = [] + for group in groups: + state_dict, missing_types = self._get_state_for_group_from_cache( + group, types + ) + + if types is not None and not missing_types: + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + else: + missing_groups_and_types.append(( + group, + missing_types if types else None + )) + + if not missing_groups_and_types: + defer.returnValue(results) # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence - _, state_ids = yield self._get_state_groups_from_group( - group, - frozenset(missing_types) if types else None + + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types ) - state_events = yield self._get_events(state_ids, get_prev_content=False) - state_dict = { - key: None - for key in missing_types - } - state_dict.update({ - (e.type, e.state_key): e + + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False + ) + + state_events = { + e.event_id: e for e in state_events - }) + } - # Update the cache - self._state_group_cache.update( - cache_seq_num, - key=group, - value=state_dict, - full=(types is None), - ) + for group, state_ids in group_state_dict.items(): + state_dict = { + key: None + for key in missing_types + } + evs = [state_events[e_id] for e_id in state_ids] + state_dict.update({ + (e.type, e.state_key): e + for e in evs + }) - defer.returnValue({ - key: value - for key, value in state_dict.items() - if value - }) + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + + defer.returnValue(results) def _make_group_id(clock): From 5119e416e855e1c5d78e8bd711bcbac2e15ee5f6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 10:05:30 +0100 Subject: [PATCH 17/59] Line length --- synapse/storage/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 478b38286..a19364985 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -323,7 +323,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2) + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", + num_args=2) def _get_state_group_for_events(self, room_id, event_ids): def f(txn): results = {} From aa88582e008550aa5a341ebd8c31279a63074df9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 10:08:15 +0100 Subject: [PATCH 18/59] Do bounds check --- synapse/storage/state.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a19364985..06f0ab3ff 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -433,7 +433,10 @@ class StateStore(SQLBaseStore): key: None for key in missing_types } - evs = [state_events[e_id] for e_id in state_ids] + evs = [ + state_events[e_id] for e_id in state_ids + if e_id in state_events # This can happen if event is rejected. + ] state_dict.update({ (e.type, e.state_key): e for e in evs From dcefac3b0637673179e4a107b8cdfb844ba3a6c9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 14:16:24 +0100 Subject: [PATCH 19/59] Comments --- synapse/util/dictionary_cache.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/dictionary_cache.py index 38b131677..c7564cdf0 100644 --- a/synapse/util/dictionary_cache.py +++ b/synapse/util/dictionary_cache.py @@ -26,6 +26,9 @@ DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) class DictionaryCache(object): + """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. + fetching a subset of dictionary keys for a particular key. + """ def __init__(self, name, max_entries=1000): self.cache = LruCache(max_size=max_entries) From bb0a475c30e63a6d9d749a1f8587582904d3ff92 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 14:27:38 +0100 Subject: [PATCH 20/59] Comments --- synapse/storage/_base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 27713e8b7..826c393cd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -234,6 +234,12 @@ class CacheDescriptor(object): class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): self.orig = orig From 2c019eea11ce83cbb64ae2ea75e7e7aa260e5f48 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 14:44:41 +0100 Subject: [PATCH 21/59] Remove unused function --- synapse/storage/state.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 06f0ab3ff..b83f644fd 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -65,13 +65,6 @@ class StateStore(SQLBaseStore): for group, state_map in group_to_state.items() }) - def _fetch_events_for_group(self, key, events): - return self._get_events( - events, get_prev_content=False - ).addCallback( - lambda evs: (key, evs) - ) - def _store_state_groups_txn(self, txn, event, context): return self._store_mult_state_groups_txn(txn, [(event, context)]) From 017b798e4f3dedaf0c02fba4d5ec53b7130d6ef2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 15:01:06 +0100 Subject: [PATCH 22/59] Clean up StateStore --- synapse/storage/state.py | 73 ++++++---------------------------------- 1 file changed, 11 insertions(+), 62 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b83f644fd..64c5ae992 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -171,34 +171,9 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - def _get_state_groups_from_group(self, group, types): - def f(txn): - if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) - else: - where_clause = "" - - sql = ( - "SELECT event_id FROM state_groups_state WHERE" - " state_group = ? %s" - ) % (where_clause,) - - args = [group] - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - - return [r[0] for r in txn.fetchall()] - - return self.runInteraction( - "_get_state_groups_from_group", - f, - ) - def _get_state_groups_from_groups(self, groups_and_types): + """Returns dictionary state_group -> state event ids + """ def f(txn): results = {} for group, types in groups_and_types: @@ -232,41 +207,6 @@ class StateStore(SQLBaseStore): f, ) - @cached(num_args=3, lru=True, max_entries=10000) - def _get_state_for_event_id(self, room_id, event_id, types): - def f(txn): - type_and_state_sql = " OR ".join([ - "(type = ? AND state_key = ?)" - if typ[1] is not None - else "type = ?" - for typ in types - ]) - - sql = ( - "SELECT e.event_id, sg.state_group, sg.event_id" - " FROM state_groups_state as sg" - " INNER JOIN event_to_state_groups as e" - " ON e.state_group = sg.state_group" - " WHERE e.event_id = ? AND (%s)" - ) % (type_and_state_sql,) - - args = [event_id] - for typ, state_key in types: - args.extend( - [typ, state_key] if state_key is not None else [typ] - ) - txn.execute(sql, args) - - return event_id, [ - r[0] - for r in txn.fetchall() - ] - - return self.runInteraction( - "_get_state_for_event_id", - f, - ) - @defer.inlineCallbacks def get_state_for_events(self, room_id, event_ids, types): """Given a list of event_ids and type tuples, return a list of state @@ -319,6 +259,8 @@ class StateStore(SQLBaseStore): @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2) def _get_state_group_for_events(self, room_id, event_ids): + """Returns mapping event_id -> state_group + """ def f(txn): results = {} for event_id in event_ids: @@ -340,6 +282,8 @@ class StateStore(SQLBaseStore): ) def _get_state_for_group_from_cache(self, group, types=None): + """Checks if group is in cache. See `_get_state_for_groups` + """ is_all, state_dict = self._state_group_cache.get(group) type_to_key = {} @@ -382,6 +326,11 @@ class StateStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): + """Given list of groups returns dict of group -> list of state events + with matching types. `types` is a list of `(type, state_key)`, where + a `state_key` of None matches all state_keys. If `types` is None then + all events are returned. + """ results = {} missing_groups_and_types = [] for group in groups: From 10b874067b2755a5fce751ab5eb02da26b1e5eaa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 09:12:41 +0100 Subject: [PATCH 23/59] Fix state cache --- synapse/storage/state.py | 83 +++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 35 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 64c5ae992..19b16ed40 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -283,6 +283,9 @@ class StateStore(SQLBaseStore): def _get_state_for_group_from_cache(self, group, types=None): """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the + list of types that aren't in the cache for that group. """ is_all, state_dict = self._state_group_cache.get(group) @@ -300,29 +303,31 @@ class StateStore(SQLBaseStore): if (typ, state_key) not in state_dict: missing_types.add((typ, state_key)) - if is_all and types is None: - return state_dict, missing_types + if is_all: + missing_types = set() + if types is None: + return state_dict, set(), True - if is_all or (types is not None and not missing_types): - sentinel = object() + sentinel = object() - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return False - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True + def include(typ, state_key): + if types is None: + return True + + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False - return { - k: v - for k, v in state_dict.items() - if v and include(k[0], k[1]) - }, missing_types - - return {}, missing_types + return { + k: v + for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, not missing_types and types is not None @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): @@ -333,25 +338,28 @@ class StateStore(SQLBaseStore): """ results = {} missing_groups_and_types = [] - for group in groups: - state_dict, missing_types = self._get_state_for_group_from_cache( + for group in set(groups): + state_dict, missing_types, got_all = self._get_state_for_group_from_cache( group, types ) - if types is not None and not missing_types: - results[group] = { - key: value - for key, value in state_dict.items() - if value - } - else: + results[group] = state_dict + + if not got_all: missing_groups_and_types.append(( group, missing_types if types else None )) if not missing_groups_and_types: - defer.returnValue(results) + defer.returnValue({ + k: { + key: ev + for key, ev in state.items() + if ev + } + for k, state in results.items() + }) # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence @@ -371,10 +379,15 @@ class StateStore(SQLBaseStore): } for group, state_ids in group_state_dict.items(): - state_dict = { - key: None - for key in missing_types - } + if types: + state_dict = { + key: None + for key in types + } + state_dict.update(results[group]) + else: + state_dict = results[group] + evs = [ state_events[e_id] for e_id in state_ids if e_id in state_events # This can happen if event is rejected. @@ -392,11 +405,11 @@ class StateStore(SQLBaseStore): full=(types is None), ) - results[group] = { + results[group].update({ key: value for key, value in state_dict.items() if value - } + }) defer.returnValue(results) From 1b994a97dd201d0f122a416f28dbbf1136304412 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 10:41:40 +0100 Subject: [PATCH 24/59] Fix application of ACLs --- synapse/handlers/federation.py | 11 +++++------ synapse/handlers/message.py | 16 ++++++++++++---- synapse/handlers/sync.py | 17 +++++++++++++---- synapse/storage/state.py | 6 +++--- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 90649af9e..2bfd0a40e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -229,7 +229,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - states = yield self.store.get_state_for_events( + event_to_state = yield self.store.get_state_for_events( room_id, frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -237,8 +237,6 @@ class FederationHandler(BaseHandler): ) ) - events_and_states = zip(events, states) - def redact_disallowed(event_and_state): event, state = event_and_state @@ -275,9 +273,10 @@ class FederationHandler(BaseHandler): return event - res = map(redact_disallowed, events_and_states) - - logger.info("_filter_events_for_server %r", res) + res = map(redact_disallowed, [ + (e, event_to_state[e.event_id]) + for e in events + ]) defer.returnValue(res) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 11c736f72..95a8f05c0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -137,7 +137,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): - states = yield self.store.get_state_for_events( + event_id_to_state = yield self.store.get_state_for_events( room_id, frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -145,7 +145,8 @@ class MessageHandler(BaseHandler): ) ) - events_and_states = zip(events, states) + for ev, state in event_id_to_state.items(): + logger.info("event_id: %r, state: %r", ev, state) def allowed(event_and_state): event, state = event_and_state @@ -179,10 +180,17 @@ class MessageHandler(BaseHandler): return True - events_and_states = filter(allowed, events_and_states) + event_and_state = filter( + allowed, + [ + (e, event_id_to_state[e.event_id]) + for e in events + ] + ) + defer.returnValue([ ev - for ev, _ in events_and_states + for ev, _ in event_and_state ]) @defer.inlineCallbacks diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8f58774b3..9a97bff84 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -294,7 +294,7 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): - states = yield self.store.get_state_for_events( + event_id_to_state = yield self.store.get_state_for_events( room_id, frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -302,7 +302,8 @@ class SyncHandler(BaseHandler): ) ) - events_and_states = zip(events, states) + for ev, state in event_id_to_state.items(): + logger.info("event_id: %r, state: %r", ev, state) def allowed(event_and_state): event, state = event_and_state @@ -335,10 +336,18 @@ class SyncHandler(BaseHandler): return membership == Membership.INVITE return True - events_and_states = filter(allowed, events_and_states) + + event_and_state = filter( + allowed, + [ + (e, event_id_to_state[e.event_id]) + for e in events + ] + ) + defer.returnValue([ ev - for ev, _ in events_and_states + for ev, _ in event_and_state ]) @defer.inlineCallbacks diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 19b16ed40..a43853007 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -239,10 +239,10 @@ class StateStore(SQLBaseStore): for event_id, group in event_to_groups.items() } - defer.returnValue([ - event_to_state[event] + defer.returnValue({ + event: event_to_state[event] for event in event_ids - ]) + }) @cached(num_args=2, lru=True, max_entries=100000) def _get_state_group_for_event(self, room_id, event_id): From dc8399ee0059bb4ee93fb7c755bc36ade16230a8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:30:59 +0100 Subject: [PATCH 25/59] Remove debug loggers --- synapse/handlers/message.py | 3 --- synapse/handlers/sync.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 95a8f05c0..b941312ef 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -145,9 +145,6 @@ class MessageHandler(BaseHandler): ) ) - for ev, state in event_id_to_state.items(): - logger.info("event_id: %r, state: %r", ev, state) - def allowed(event_and_state): event, state = event_and_state diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9a97bff84..d960078e7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -302,9 +302,6 @@ class SyncHandler(BaseHandler): ) ) - for ev, state in event_id_to_state.items(): - logger.info("event_id: %r, state: %r", ev, state) - def allowed(event_and_state): event, state = event_and_state From 4762c276cb460044099a98a4343c3f2b0ce2abe4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:33:41 +0100 Subject: [PATCH 26/59] Docs --- synapse/storage/_base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 826c393cd..bdca61d2e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -241,6 +241,15 @@ class CacheListDescriptor(object): """ def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ self.orig = orig if inlineCallbacks: From 6eaa116867ebfd22718505d1aa29dce4679c87be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:35:24 +0100 Subject: [PATCH 27/59] Comment --- synapse/storage/_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bdca61d2e..4d86fe7c7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -288,6 +288,8 @@ class CacheListDescriptor(object): keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) cached = {} missing = [] for arg in list_args: From 53a817518bd27e6838c747a03d605165f96dbd12 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:40:40 +0100 Subject: [PATCH 28/59] Comments --- synapse/storage/_base.py | 2 ++ synapse/storage/state.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 4d86fe7c7..e5441aafb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -315,6 +315,8 @@ class CacheListDescriptor(object): ret_d = ObservableDeferred(ret_d) + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. for arg in missing: observer = ret_d.observe() observer.addCallback(lambda r, arg: r[arg], arg) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a43853007..ea5fa9de7 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -173,6 +173,9 @@ class StateStore(SQLBaseStore): def _get_state_groups_from_groups(self, groups_and_types): """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) """ def f(txn): results = {} @@ -284,8 +287,11 @@ class StateStore(SQLBaseStore): def _get_state_for_group_from_cache(self, group, types=None): """Checks if group is in cache. See `_get_state_for_groups` - Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the - list of types that aren't in the cache for that group. + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group, or None if `types` is None. `got_all` is a bool indicating if + we successfully retrieved all requests state from the cache, if False + we need to query the DB for the missing state. """ is_all, state_dict = self._state_group_cache.get(group) @@ -323,11 +329,13 @@ class StateStore(SQLBaseStore): return True return False + got_all = not (missing_types or types is None) + return { k: v for k, v in state_dict.items() if include(k[0], k[1]) - }, missing_types, not missing_types and types is not None + }, missing_types, got_all @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): From 2df8dd9b37f26e3ad0d3647a1e78804a85d48c0c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 17:59:32 +0100 Subject: [PATCH 29/59] Move all the caches into their own package, synapse.util.caches --- synapse/federation/federation_client.py | 2 +- synapse/state.py | 2 +- synapse/storage/_base.py | 335 +--------------- synapse/storage/directory.py | 3 +- synapse/storage/event_federation.py | 3 +- synapse/storage/keys.py | 3 +- synapse/storage/presence.py | 3 +- synapse/storage/push_rule.py | 3 +- synapse/storage/receipts.py | 3 +- synapse/storage/registration.py | 3 +- synapse/storage/room.py | 3 +- synapse/storage/roommember.py | 3 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 3 +- synapse/storage/transactions.py | 3 +- synapse/util/caches/__init__.py | 14 + synapse/util/caches/descriptors.py | 359 ++++++++++++++++++ synapse/util/{ => caches}/dictionary_cache.py | 2 +- synapse/util/{ => caches}/expiringcache.py | 0 synapse/util/{ => caches}/lrucache.py | 0 tests/storage/test__base.py | 2 +- tests/util/test_dict_cache.py | 2 +- tests/util/test_lrucache.py | 4 +- 23 files changed, 408 insertions(+), 352 deletions(-) create mode 100644 synapse/util/caches/__init__.py create mode 100644 synapse/util/caches/descriptors.py rename synapse/util/{ => caches}/dictionary_cache.py (98%) rename synapse/util/{ => caches}/expiringcache.py (100%) rename synapse/util/{ => caches}/lrucache.py (100%) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb..58a6d6a0e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,7 +23,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics diff --git a/synapse/state.py b/synapse/state.py index b5e5d7bbd..1fe4d066b 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e5441aafb..1444767a5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,27 +15,22 @@ import logging from synapse.api.errors import StoreError -from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext -from synapse.util.lrucache import LruCache -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.caches.descriptors import Cache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple, OrderedDict +from collections import namedtuple -import functools -import inspect import sys import time import threading -DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - -_CacheSentinel = object() - - -class Cache(object): - - def __init__(self, name, max_entries=1000, keylen=1, lru=True): - if lru: - self.cache = LruCache(max_size=max_entries) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries - - self.name = name - self.keylen = keylen - self.sequence = 0 - self.thread = None - caches_by_name[name] = self.cache - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) - return val - - cache_counter.inc_misses(self.name) - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def update(self, sequence, key, value): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[key] = value - - def invalidate(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise ValueError("keyargs must be a tuple.") - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - -class CacheDescriptor(object): - """ A method decorator that applies a memoizing cache around the function. - - This caches deferreds, rather than the results themselves. Deferreds that - fail are removed from the cache. - - The function is presumed to take zero or more arguments, which are used in - a tuple as the key for the cache. Hits are served directly from the cache; - misses use the function body to generate the value. - - The wrapped function has an additional member, a callable called - "invalidate". This can be used to remove individual entries from the cache. - - The wrapped function has another additional callable, called "prefill", - which can be used to insert values into the cache specifically, without - calling the calculation function. - """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, - inlineCallbacks=False): - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.max_entries = max_entries - self.num_args = num_args - self.lru = lru - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - self.cache = Cache( - name=self.orig.__name__, - max_entries=self.max_entries, - keylen=self.num_args, - lru=self.lru, - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - try: - cached_result_d = self.cache.get(cache_key) - - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) - - return observer - except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence - - ret = defer.maybeDeferred( - self.function_to_call, - obj, *args, **kwargs - ) - - def onErr(f): - self.cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) - - return ret.observe() - - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.prefill = self.cache.prefill - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -class CacheListDescriptor(object): - """Wraps an existing cache to support bulk fetching of keys. - - Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. - """ - - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): - """ - Args: - orig (function) - cache (Cache) - list_name (str): Name of the argument which is the bulk lookup list - num_args (int) - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks - """ - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args - self.list_name = list_name - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - self.list_pos = self.arg_names.index(self.list_name) - - self.cache = cache - - self.sentinel = object() - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - if self.list_name not in self.arg_names: - raise Exception( - "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] - list_args = arg_dict[self.list_name] - - # cached is a dict arg -> deferred, where deferred results in a - # 2-tuple (`arg`, `result`) - cached = {} - missing = [] - for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg - - try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res - except KeyError: - missing.append(arg) - - if missing: - sequence = self.cache.sequence - args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing - - ret_d = defer.maybeDeferred( - self.function_to_call, - **args_to_call - ) - - ret_d = ObservableDeferred(ret_d) - - # We need to create deferreds for each arg in the list so that - # we can insert the new deferred into the cache. - for arg in missing: - observer = ret_d.observe() - observer.addCallback(lambda r, arg: r[arg], arg) - - observer = ObservableDeferred(observer) - - key = list(keyargs) - key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) - - def invalidate(f, key): - self.cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) - - res = observer.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - - cached[arg] = res - - return defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -def cached(max_entries=1000, num_args=1, lru=True): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru - ) - - -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru, - inlineCallbacks=True, - ) - - -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): - return lambda orig: CacheListDescriptor( - orig, - cache=cache, - list_name=list_name, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - ) - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index f3947bbe8..d92028ea4 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.errors import SynapseError diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 910b6598a..25cc84eb9 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from syutil.base64util import encode_base64 import logging diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 49b8e37cf..ffd6daa88 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cachedInlineCallbacks +from _base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 576cf670c..4f91a2b87 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from twisted.internet import defer diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9b88ca7b3..5305b7e12 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer import logging diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b79d6683c..cac1a5657 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4eaa088b3..aa446f94c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached class RegistrationStore(SQLBaseStore): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dd5bc2c8f..5e07b7e0e 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import collections import logging diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9f14f38f2..8eee2dfbc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,8 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.constants import Membership from synapse.types import UserID diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ea5fa9de7..79c3b82d9 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b59fe8100..d7fe423f5 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,7 +35,8 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 624da4a9d..c8c7e6591 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from collections import namedtuple diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py new file mode 100644 index 000000000..1a84d94cd --- /dev/null +++ b/synapse/util/caches/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py new file mode 100644 index 000000000..82dd09cf5 --- /dev/null +++ b/synapse/util/caches/descriptors.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError +from synapse.util.caches.lrucache import LruCache +import synapse.metrics + +from twisted.internet import defer + +from collections import OrderedDict + +import functools +import inspect +import threading + +logger = logging.getLogger(__name__) + + +DEBUG_CACHES = False + +metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +caches_by_name = {} +cache_counter = metrics.register_cache( + "cache", + lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, + labels=["name"], +) + + +_CacheSentinel = object() + + +class Cache(object): + + def __init__(self, name, max_entries=1000, keylen=1, lru=True): + if lru: + self.cache = LruCache(max_size=max_entries) + self.max_entries = None + else: + self.cache = OrderedDict() + self.max_entries = max_entries + + self.name = name + self.keylen = keylen + self.sequence = 0 + self.thread = None + caches_by_name[name] = self.cache + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val + + cache_counter.inc_misses(self.name) + + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(key, value) + + def prefill(self, key, value): + if self.max_entries is not None: + while len(self.cache) >= self.max_entries: + self.cache.popitem(last=False) + + self.cache[key] = value + + def invalidate(self, key): + self.check_thread() + if not isinstance(key, tuple): + raise ValueError("keyargs must be a tuple.") + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + +class CacheDescriptor(object): + """ A method decorator that applies a memoizing cache around the function. + + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + + The function is presumed to take zero or more arguments, which are used in + a tuple as the key for the cache. Hits are served directly from the cache; + misses use the function body to generate the value. + + The wrapped function has an additional member, a callable called + "invalidate". This can be used to remove individual entries from the cache. + + The wrapped function has another additional callable, called "prefill", + which can be used to insert values into the cache specifically, without + calling the calculation function. + """ + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + try: + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() + if DEBUG_CACHES: + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer + except KeyError: + # Get the sequence number of the cache before reading from the + # database so that we can tell if the cache is invalidated + # while the SELECT is executing (SYN-369) + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) + + return ret.observe() + + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +def cached(max_entries=1000, num_args=1, lru=True): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) + + +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py similarity index 98% rename from synapse/util/dictionary_cache.py rename to synapse/util/caches/dictionary_cache.py index c7564cdf0..26d464f4f 100644 --- a/synapse/util/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache from collections import namedtuple import threading import logging diff --git a/synapse/util/expiringcache.py b/synapse/util/caches/expiringcache.py similarity index 100% rename from synapse/util/expiringcache.py rename to synapse/util/caches/expiringcache.py diff --git a/synapse/util/lrucache.py b/synapse/util/caches/lrucache.py similarity index 100% rename from synapse/util/lrucache.py rename to synapse/util/caches/lrucache.py diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index abee2f631..e72cace8f 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.util.async import ObservableDeferred -from synapse.storage._base import Cache, cached +from synapse.util.caches.descriptors import Cache, cached class CacheTestCase(unittest.TestCase): diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 79bc1225d..54ff26cd9 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -17,7 +17,7 @@ from twisted.internet import defer from tests import unittest -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache class DictCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index ab934bf92..fc5a90432 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -16,7 +16,7 @@ from .. import unittest -from synapse.util.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache class LruCacheTestCase(unittest.TestCase): @@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 self.assertEquals(cache.pop("key"), 1) self.assertEquals(cache.pop("key"), None) - - From 4807616e1615bdaaee56f800ba682d0d019de610 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 10:13:35 +0100 Subject: [PATCH 30/59] Wire up the dictionarycache to the metrics --- synapse/util/caches/__init__.py | 13 +++++++ synapse/util/caches/descriptors.py | 17 ++------ synapse/util/caches/dictionary_cache.py | 52 +++++++++++-------------- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 1a84d94cd..da0e06a46 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -12,3 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import synapse.metrics + +DEBUG_CACHES = False + +metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +caches_by_name = {} +cache_counter = metrics.register_cache( + "cache", + lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, + labels=["name"], +) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 82dd09cf5..c99fda849 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2015 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -import synapse.metrics + +from . import caches_by_name, DEBUG_CACHES, cache_counter from twisted.internet import defer @@ -30,18 +31,6 @@ import threading logger = logging.getLogger(__name__) -DEBUG_CACHES = False - -metrics = synapse.metrics.get_metrics_for("synapse.util.caches") - -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - _CacheSentinel = object() diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 26d464f4f..e69adf62f 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -15,6 +15,7 @@ from synapse.util.caches.lrucache import LruCache from collections import namedtuple +from . import caches_by_name, cache_counter import threading import logging @@ -42,6 +43,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() + caches_by_name[name] = self.cache def check_thread(self): expected_thread = self.thread @@ -54,25 +56,21 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): - try: - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - # cache_counter.inc_hits(self.name) + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + cache_counter.inc_hits(self.name) - if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) - else: - return DictionaryEntry(entry.full, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) - # cache_counter.inc_misses(self.name) - return DictionaryEntry(False, {}) - except: - logger.exception("get failed") - raise + cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) def invalidate(self, key): self.check_thread() @@ -88,18 +86,14 @@ class DictionaryCache(object): self.cache.clear() def update(self, sequence, key, value, full=False): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - if full: - self._insert(key, value) - else: - self._update_or_insert(key, value) - except: - logger.exception("update failed") - raise + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) def _update_or_insert(self, key, value): entry = self.cache.setdefault(key, DictionaryEntry(False, {})) From f7e2f981ea1feb8461a5bddd9378bd5084833fc0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 16:01:10 +0100 Subject: [PATCH 31/59] Use list comprehension instead of filter --- synapse/handlers/message.py | 13 +++---------- synapse/handlers/sync.py | 13 +++---------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b941312ef..2c4af8dc9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -177,17 +177,10 @@ class MessageHandler(BaseHandler): return True - event_and_state = filter( - allowed, - [ - (e, event_id_to_state[e.event_id]) - for e in events - ] - ) - defer.returnValue([ - ev - for ev, _ in event_and_state + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) ]) @defer.inlineCallbacks diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d960078e7..ec8d78ba8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -334,17 +334,10 @@ class SyncHandler(BaseHandler): return True - event_and_state = filter( - allowed, - [ - (e, event_id_to_state[e.event_id]) - for e in events - ] - ) - defer.returnValue([ - ev - for ev, _ in event_and_state + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) ]) @defer.inlineCallbacks From a7eeb34c64d828539dd6799f2347371a8eabae73 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 16:02:05 +0100 Subject: [PATCH 32/59] Simplify staggered deferred lists --- synapse/handlers/message.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2c4af8dc9..8a9e6cf6c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -406,9 +406,9 @@ class MessageHandler(BaseHandler): # Only do N rooms at once n = 5 d_list = [handle_room(e) for e in room_list] - for ds in [d_list[i:i + n] for i in range(0, len(d_list), n)]: + for i in range(0, len(d_list), n): yield defer.gatherResults( - ds, + d_list[i:i + n], consumeErrors=True ).addErrback(unwrapFirstError) From cfa62007a392a2d6da818a45999c71da44a6da12 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 16:42:46 +0100 Subject: [PATCH 33/59] Docstring --- synapse/util/caches/descriptors.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 408c3b5e6..83bfec2f0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -341,6 +341,33 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + """Creates a descriptor that wraps a function in a `CacheListDescriptor`. + + Used to do batch lookups for an already created cache. A single argument + is specified as a list that is iterated through to lookup keys in the + original cache. A new list consisting of the keys that weren't in the cache + get passed to the original function, the result of which is stored in the + cache. + + Args: + cache (Cache): The underlying cache to use. + list_name (str): The name of the argument that is the list to use to + do batch lookups in the cache. + num_args (int): Number of arguments to use as the key in the cache. + inlineCallbacks (bool): Should the function be wrapped in an + `defer.inlineCallbacks`? + + Example: + + class Example(object): + @cached(num_args=2) + def do_something(self, first_arg): + ... + + @cachedList(do_something.cache, list_name="second_args", num_args=2) + def batch_do_something(self, first_arg, second_args): + ... + """ return lambda orig: CacheListDescriptor( orig, cache=cache, From 2eb91e6694f43854c28fe33326cb2e8be4ac69c5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 12 Aug 2015 16:53:30 +0100 Subject: [PATCH 34/59] enable registration in the demo servers --- demo/start.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/demo/start.sh b/demo/start.sh index b5dea5e17..572dbfab0 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -23,7 +23,6 @@ for port in 8080 8081 8082; do #rm $DIR/etc/$port.config python -m synapse.app.homeserver \ --generate-config \ - --enable_registration \ -H "localhost:$https_port" \ --config-path "$DIR/etc/$port.config" \ @@ -36,6 +35,8 @@ for port in 8080 8081 8082; do fi fi + perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config + python -m synapse.app.homeserver \ --config-path "$DIR/etc/$port.config" \ -D \ From 7b0e7970800df8cedf7966e6fb3837ee233d9ea4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 17:05:24 +0100 Subject: [PATCH 35/59] Fix _filter_events_for_client --- synapse/handlers/message.py | 4 +--- synapse/handlers/sync.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a9e6cf6c..29e81085d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -145,9 +145,7 @@ class MessageHandler(BaseHandler): ) ) - def allowed(event_and_state): - event, state = event_and_state - + def allowed(event, state): if event.type == EventTypes.RoomHistoryVisibility: return True diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ec8d78ba8..7206ae23d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -302,9 +302,7 @@ class SyncHandler(BaseHandler): ) ) - def allowed(event_and_state): - event, state = event_and_state - + def allowed(event, state): if event.type == EventTypes.RoomHistoryVisibility: return True From df361d08f7b778036cd0def86289d511267af849 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 17:06:21 +0100 Subject: [PATCH 36/59] Split _get_state_for_group_from_cache into two --- synapse/storage/state.py | 85 +++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 79c3b82d9..129384236 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -287,42 +287,39 @@ class StateStore(SQLBaseStore): f, ) - def _get_state_for_group_from_cache(self, group, types=None): + def _get_some_state_from_cache(self, group, types): """Checks if group is in cache. See `_get_state_for_groups` Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). `missing_types` is the list of types that aren't in the cache for that - group, or None if `types` is None. `got_all` is a bool indicating if - we successfully retrieved all requests state from the cache, if False - we need to query the DB for the missing state. + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. """ is_all, state_dict = self._state_group_cache.get(group) type_to_key = {} missing_types = set() - if types is not None: - for typ, state_key in types: - if state_key is None: - type_to_key[typ] = None + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: missing_types.add((typ, state_key)) - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - - if (typ, state_key) not in state_dict: - missing_types.add((typ, state_key)) - - if is_all: - missing_types = set() - if types is None: - return state_dict, set(), True sentinel = object() def include(typ, state_key): - if types is None: - return True - valid_state_keys = type_to_key.get(typ, sentinel) if valid_state_keys is sentinel: return False @@ -340,6 +337,19 @@ class StateStore(SQLBaseStore): if include(k[0], k[1]) }, missing_types, got_all + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): """Given list of groups returns dict of group -> list of state events @@ -349,18 +359,29 @@ class StateStore(SQLBaseStore): """ results = {} missing_groups_and_types = [] - for group in set(groups): - state_dict, missing_types, got_all = self._get_state_for_group_from_cache( - group, types - ) + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) - results[group] = state_dict + results[group] = state_dict - if not got_all: - missing_groups_and_types.append(( - group, - missing_types if types else None - )) + if not got_all: + missing_groups_and_types.append(( + group, + missing_types + )) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group + ) + + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, None)) if not missing_groups_and_types: defer.returnValue({ From 101ee3fd0022162da2ea2dc3f1b1846a08e21f9e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 17:08:05 +0100 Subject: [PATCH 37/59] Better variable name --- synapse/storage/state.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 129384236..a3fc859b0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -385,12 +385,12 @@ class StateStore(SQLBaseStore): if not missing_groups_and_types: defer.returnValue({ - k: { - key: ev - for key, ev in state.items() - if ev + group: { + type_tuple: event + for type_tuple, event in state.items() + if event } - for k, state in results.items() + for group, state in results.items() }) # Okay, so we have some missing_types, lets fetch them. From c10ac7806e2cde04f344b641d01efa46d34f9985 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 17:16:30 +0100 Subject: [PATCH 38/59] Explain why we're prefilling dict with Nones --- synapse/storage/state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a3fc859b0..57e334cc3 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -412,6 +412,10 @@ class StateStore(SQLBaseStore): for group, state_ids in group_state_dict.items(): if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. state_dict = { key: None for key in types From 0fbed2a8fac2c04ac4f46645aa9757fdca8b7cc6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2015 17:22:54 +0100 Subject: [PATCH 39/59] Comment --- synapse/storage/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 57e334cc3..abc5b6643 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -410,6 +410,8 @@ class StateStore(SQLBaseStore): for e in state_events } + # Now we want to update the cache with all the things we fetched + # from the database. for group, state_ids in group_state_dict.items(): if types: # We delibrately put key -> None mappings into the cache to @@ -433,7 +435,6 @@ class StateStore(SQLBaseStore): for e in evs }) - # Update the cache self._state_group_cache.update( cache_seq_num, key=group, From 21ac8be5f7be79cfffb32ad9fe1bba9515d9fd3e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 12 Aug 2015 17:25:13 +0100 Subject: [PATCH 40/59] Depend on Twisted>=15.1 rather than pining to a particular version --- setup.py | 2 +- synapse/python_dependencies.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c790484c7..16ccc0f1b 100755 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ setup( description="Reference Synapse Home Server", install_requires=dependencies['requirements'](include_conditional=True).keys(), setup_requires=[ - "Twisted==15.2.1", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 + "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "setuptools_trial", "mock" ], diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 534c4a669..fa06480ad 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil>=0.0.7": ["syutil>=0.0.7"], - "Twisted==15.2.1": ["twisted==15.2.1"], + "Twisted>=15.1.0": ["twisted>=15.1.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], From ba5d34a83274127a2c0226059778d226355bdb6c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 11:38:59 +0100 Subject: [PATCH 41/59] Add some metrics about the reactor --- synapse/metrics/__init__.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 9233ea3da..e014b415f 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -18,8 +18,12 @@ from __future__ import absolute_import import logging from resource import getrusage, getpagesize, RUSAGE_SELF +import functools import os import stat +import time + +from twisted.internet import reactor from .metric import ( CounterMetric, CallbackMetric, DistributionMetric, CacheMetric @@ -144,3 +148,28 @@ def _process_fds(): return counts get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) + +reactor_metrics = get_metrics_for("reactor") +tick_time = reactor_metrics.register_distribution("tick_time") +pending_calls_metric = reactor_metrics.register_distribution("pending_calls") + + +def runUntilCurrentTimer(func): + + @functools.wraps(func) + def f(*args, **kwargs): + start = time.time() * 1000 + pending_calls = len(reactor.getDelayedCalls()) + ret = func(*args, **kwargs) + end = time.time() * 1000 + tick_time.inc_by(end - start) + pending_calls_metric.inc_by(pending_calls) + return ret + + return f + + +if hasattr(reactor, "runUntilCurrent"): + # runUntilCurrent is called when we have pending calls. It is called once + # per iteratation after fd polling. + reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) From a6c27de1aa283a9d7a347db53b08367c53a15a24 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 11:41:57 +0100 Subject: [PATCH 42/59] Don't time getDelayedCalls --- synapse/metrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index e014b415f..0be977299 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -158,8 +158,8 @@ def runUntilCurrentTimer(func): @functools.wraps(func) def f(*args, **kwargs): - start = time.time() * 1000 pending_calls = len(reactor.getDelayedCalls()) + start = time.time() * 1000 ret = func(*args, **kwargs) end = time.time() * 1000 tick_time.inc_by(end - start) From adbd720fab5f6ee42c4e14b06eb6f385bae14dc6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 11:47:38 +0100 Subject: [PATCH 43/59] PEP8 --- synapse/http/matrixfederationclient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4d74bd5d7..854e17a47 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -142,7 +142,6 @@ class MatrixFederationHttpClient(object): producer ) - return self.clock.time_bound_deferred( request_deferred, time_out=timeout/1000. if timeout else 60, From 7eb4d626bada90834f8cfe464e4424ed8991e590 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 13 Aug 2015 13:12:33 +0100 Subject: [PATCH 44/59] Add max-line-length to the flake8 section of setup.cfg --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 888ad6ed4..abb649958 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,3 +16,6 @@ ignore = docs/* pylint.cfg tox.ini + +[flake8] +max-line-length = 90 From 57877b01d7dad613f4e1dc3fb99983320ea50d40 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 17:00:17 +0100 Subject: [PATCH 45/59] Replace list comprehension --- synapse/storage/state.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index abc5b6643..1f5e62fa8 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -426,14 +426,9 @@ class StateStore(SQLBaseStore): else: state_dict = results[group] - evs = [ - state_events[e_id] for e_id in state_ids - if e_id in state_events # This can happen if event is rejected. - ] - state_dict.update({ - (e.type, e.state_key): e - for e in evs - }) + for event_id in state_ids: + state_event = state_events[event_id] + state_dict[(state_event.type, state_event.state_key)] = state_event self._state_group_cache.update( cache_seq_num, From 2bb2c025711f8ff8ea044a203059ca3ea6b94749 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 17:11:30 +0100 Subject: [PATCH 46/59] Remove some vertical space --- synapse/storage/state.py | 45 ++++++++-------------------------------- 1 file changed, 9 insertions(+), 36 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 1f5e62fa8..185f88fd7 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -60,7 +60,6 @@ class StateStore(SQLBaseStore): ) groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) defer.returnValue({ @@ -201,10 +200,7 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) - results[group] = [ - r[0] - for r in txn.fetchall() - ] + results[group] = [r[0] for r in txn.fetchall()] return results @@ -235,20 +231,14 @@ class StateStore(SQLBaseStore): ) groups = set(event_to_groups.values()) - - group_to_state = yield self._get_state_for_groups( - groups, types - ) + group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] for event_id, group in event_to_groups.items() } - defer.returnValue({ - event: event_to_state[event] - for event in event_ids - }) + defer.returnValue({event: event_to_state[event] for event in event_ids}) @cached(num_args=2, lru=True, max_entries=100000) def _get_state_group_for_event(self, room_id, event_id): @@ -282,10 +272,7 @@ class StateStore(SQLBaseStore): return results - return self.runInteraction( - "_get_state_group_for_events", - f, - ) + return self.runInteraction("_get_state_group_for_events", f) def _get_some_state_from_cache(self, group, types): """Checks if group is in cache. See `_get_state_for_groups` @@ -332,8 +319,7 @@ class StateStore(SQLBaseStore): got_all = not (missing_types or types is None) return { - k: v - for k, v in state_dict.items() + k: v for k, v in state_dict.items() if include(k[0], k[1]) }, missing_types, got_all @@ -364,20 +350,15 @@ class StateStore(SQLBaseStore): state_dict, missing_types, got_all = self._get_some_state_from_cache( group, types ) - results[group] = state_dict if not got_all: - missing_groups_and_types.append(( - group, - missing_types - )) + missing_groups_and_types.append((group, missing_types)) else: for group in set(groups): state_dict, got_all = self._get_all_state_from_cache( group ) - results[group] = state_dict if not got_all: @@ -405,10 +386,7 @@ class StateStore(SQLBaseStore): get_prev_content=False ) - state_events = { - e.event_id: e - for e in state_events - } + state_events = {e.event_id: e for e in state_events} # Now we want to update the cache with all the things we fetched # from the database. @@ -418,10 +396,7 @@ class StateStore(SQLBaseStore): # cache absence of the key, on the assumption that if we've # explicitly asked for some types then we will probably ask # for them again. - state_dict = { - key: None - for key in types - } + state_dict = {key: None for key in types} state_dict.update(results[group]) else: state_dict = results[group] @@ -438,9 +413,7 @@ class StateStore(SQLBaseStore): ) results[group].update({ - key: value - for key, value in state_dict.items() - if value + key: value for key, value in state_dict.items() if value }) defer.returnValue(results) From 9f7f228ec2e2948c69ca3910d27fffdd2c2fea50 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2015 17:20:59 +0100 Subject: [PATCH 47/59] Remove pointless map --- synapse/handlers/federation.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2bfd0a40e..1e3dccf5a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -237,9 +237,7 @@ class FederationHandler(BaseHandler): ) ) - def redact_disallowed(event_and_state): - event, state = event_and_state - + def redact_disallowed(event, state): if not state: return event @@ -273,13 +271,11 @@ class FederationHandler(BaseHandler): return event - res = map(redact_disallowed, [ - (e, event_to_state[e.event_id]) + defer.returnValue([ + redact_disallowed(e, event_to_state[e.event_id]) for e in events ]) - defer.returnValue(res) - @log_function @defer.inlineCallbacks def backfill(self, dest, room_id, limit, extremities=[]): From 0cceb2ac92cde0a4289adfc6e9000c7b1c54bdae Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 13 Aug 2015 17:27:46 +0100 Subject: [PATCH 48/59] Add a few strategic new lines to break up the on_query_client_keys and on_claim_client_keys methods in federation_server.py --- synapse/federation/federation_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c32908ac2..725c6f3fa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -323,13 +323,16 @@ class FederationServer(FederationBase): else: for device_id in device_ids: query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) + defer.returnValue({"device_keys": json_result}) @defer.inlineCallbacks @@ -339,7 +342,9 @@ class FederationServer(FederationBase): for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -347,6 +352,7 @@ class FederationServer(FederationBase): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } + defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks From f9d4da7f4502aeefe7a5a6a9ec0d2682458e7834 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Aug 2015 09:39:12 +0100 Subject: [PATCH 49/59] Fix bug where we were leaking None into state event lists --- synapse/storage/state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 185f88fd7..ecb62e6df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -412,9 +412,10 @@ class StateStore(SQLBaseStore): full=(types is None), ) - results[group].update({ + # We replace here to remove all the entries with None values. + results[group] = { key: value for key, value in state_dict.items() if value - }) + } defer.returnValue(results) From 47abebfd6d0ea56d7ac7a565f359992fde323177 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Aug 2015 09:50:50 +0100 Subject: [PATCH 50/59] Add batched version of store.get_presence_state --- synapse/storage/presence.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4f91a2b87..f351b76a7 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from twisted.internet import defer @@ -36,6 +36,7 @@ class PresenceStore(SQLBaseStore): desc="has_presence_state", ) + @cached() def get_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -44,6 +45,23 @@ class PresenceStore(SQLBaseStore): desc="get_presence_state", ) + @cachedList(get_presence_state.cache, list_name="user_localparts") + def get_presence_states(self, user_localparts): + def f(txn): + results = {} + for user_localpart in user_localparts: + results[user_localpart] = self._simple_select_one_txn( + txn, + table="presence", + keyvalues={"user_id": user_localpart}, + retcols=["state", "status_msg", "mtime"], + desc="get_presence_state", + ) + + return results + + return self.runInteraction("get_presence_states", f) + def set_presence_state(self, user_localpart, new_state): return self._simple_update_one( table="presence", From 1a9510bb84d79f6ff78d32390195bc97ed9a439e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Aug 2015 10:40:23 +0100 Subject: [PATCH 51/59] Implement a batched presence_handler.get_state and use it --- synapse/handlers/message.py | 18 ++++------- synapse/handlers/presence.py | 63 ++++++++++++++++++++++++++++++++++++ synapse/storage/presence.py | 6 ++-- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 29e81085d..f12465fa2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -460,20 +460,14 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): - presence_defs = yield defer.DeferredList( - [ - presence_handler.get_state( - target_user=UserID.from_string(m.user_id), - auth_user=auth_user, - as_event=True, - check_auth=False, - ) - for m in room_members - ], - consumeErrors=True, + states = yield presence_handler.get_states( + target_users=[UserID.from_string(m.user_id) for m in room_members], + auth_user=auth_user, + as_event=True, + check_auth=False, ) - defer.returnValue([p for success, p in presence_defs if success]) + defer.returnValue(states.values()) receipts_handler = self.hs.get_handlers().receipts_handler diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 341a516da..33d76efe0 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -232,6 +232,69 @@ class PresenceHandler(BaseHandler): else: defer.returnValue(state) + @defer.inlineCallbacks + def get_states(self, target_users, auth_user, as_event=False, check_auth=True): + try: + local_users, remote_users = partitionbool( + target_users, + lambda u: self.hs.is_mine(u) + ) + + if check_auth: + for u in local_users: + visible = yield self.is_presence_visible( + observer_user=auth_user, + observed_user=u + ) + + if not visible: + raise SynapseError(404, "Presence information not visible") + + results = {} + if local_users: + for u in local_users: + if u in self._user_cachemap: + results[u] = self._user_cachemap[u].get_state() + + local_to_user = {u.localpart: u for u in local_users} + + states = yield self.store.get_presence_states( + [u.localpart for u in local_users if u not in results] + ) + + for local_part, state in states.items(): + res = {"presence": state["state"]} + if "status_msg" in state and state["status_msg"]: + res["status_msg"] = state["status_msg"] + results[local_to_user[local_part]] = res + + for u in remote_users: + # TODO(paul): Have remote server send us permissions set + results[u] = self._get_or_offline_usercache(u).get_state() + + for state in results.values(): + if "last_active" in state: + state["last_active_ago"] = int( + self.clock.time_msec() - state.pop("last_active") + ) + + if as_event: + for user, state in results.items(): + content = state + content["user_id"] = user.to_string() + + if "last_active" in content: + content["last_active_ago"] = int( + self._clock.time_msec() - content.pop("last_active") + ) + + results[user] = {"type": "m.presence", "content": content} + except: + logger.exception(":(") + raise + + defer.returnValue(results) + @defer.inlineCallbacks @log_function def set_state(self, target_user, auth_user, state): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index f351b76a7..15d98198e 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -50,13 +50,15 @@ class PresenceStore(SQLBaseStore): def f(txn): results = {} for user_localpart in user_localparts: - results[user_localpart] = self._simple_select_one_txn( + res = self._simple_select_one_txn( txn, table="presence", keyvalues={"user_id": user_localpart}, retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + allow_none=True, ) + if res: + results[user_localpart] = res return results From 2d97e65558f37fa0fbdd8d06545c32b410f1b5ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Aug 2015 10:46:55 +0100 Subject: [PATCH 52/59] Remember to invalidate caches --- synapse/storage/presence.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 15d98198e..9b136f311 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -21,12 +21,15 @@ from twisted.internet import defer class PresenceStore(SQLBaseStore): def create_presence(self, user_localpart): - return self._simple_insert( + res = self._simple_insert( table="presence", values={"user_id": user_localpart}, desc="create_presence", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def has_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -65,7 +68,7 @@ class PresenceStore(SQLBaseStore): return self.runInteraction("get_presence_states", f) def set_presence_state(self, user_localpart, new_state): - return self._simple_update_one( + res = self._simple_update_one( table="presence", keyvalues={"user_id": user_localpart}, updatevalues={"state": new_state["state"], @@ -74,6 +77,9 @@ class PresenceStore(SQLBaseStore): desc="set_presence_state", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( table="presence_allow_inbound", From f72ed6c6a353bad4a54cb695eae12d39fd41ad24 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 10:29:49 +0100 Subject: [PATCH 53/59] Remove debug try/catch --- synapse/handlers/presence.py | 98 +++++++++++++++++------------------- 1 file changed, 47 insertions(+), 51 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 33d76efe0..b7664e30f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -234,64 +234,60 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - try: - local_users, remote_users = partitionbool( - target_users, - lambda u: self.hs.is_mine(u) - ) + local_users, remote_users = partitionbool( + target_users, + lambda u: self.hs.is_mine(u) + ) - if check_auth: - for u in local_users: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=u - ) - - if not visible: - raise SynapseError(404, "Presence information not visible") - - results = {} - if local_users: - for u in local_users: - if u in self._user_cachemap: - results[u] = self._user_cachemap[u].get_state() - - local_to_user = {u.localpart: u for u in local_users} - - states = yield self.store.get_presence_states( - [u.localpart for u in local_users if u not in results] + if check_auth: + for u in local_users: + visible = yield self.is_presence_visible( + observer_user=auth_user, + observed_user=u ) - for local_part, state in states.items(): - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res + if not visible: + raise SynapseError(404, "Presence information not visible") - for u in remote_users: - # TODO(paul): Have remote server send us permissions set - results[u] = self._get_or_offline_usercache(u).get_state() + results = {} + if local_users: + for u in local_users: + if u in self._user_cachemap: + results[u] = self._user_cachemap[u].get_state() - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + local_to_user = {u.localpart: u for u in local_users} + + states = yield self.store.get_presence_states( + [u.localpart for u in local_users if u not in results] + ) + + for local_part, state in states.items(): + res = {"presence": state["state"]} + if "status_msg" in state and state["status_msg"]: + res["status_msg"] = state["status_msg"] + results[local_to_user[local_part]] = res + + for u in remote_users: + # TODO(paul): Have remote server send us permissions set + results[u] = self._get_or_offline_usercache(u).get_state() + + for state in results.values(): + if "last_active" in state: + state["last_active_ago"] = int( + self.clock.time_msec() - state.pop("last_active") + ) + + if as_event: + for user, state in results.items(): + content = state + content["user_id"] = user.to_string() + + if "last_active" in content: + content["last_active_ago"] = int( + self._clock.time_msec() - content.pop("last_active") ) - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() - - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) - - results[user] = {"type": "m.presence", "content": content} - except: - logger.exception(":(") - raise + results[user] = {"type": "m.presence", "content": content} defer.returnValue(results) From 776ee6d92b8672c36723b0b8dc9ae3467f34c08f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 10:30:07 +0100 Subject: [PATCH 54/59] Doc strings --- synapse/handlers/presence.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b7664e30f..1177cbe51 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def get_state(self, target_user, auth_user, as_event=False, check_auth=True): + """Get the current presence state of the given user. + + Args: + target_user (UserID): The user whose presence we want + auth_user (UserID): The user requesting the presence, used for + checking if said user is allowed to see the persence of the + `target_user` + as_event (bool): Format the return as an event or not? + check_auth (bool): Perform the auth checks or not? + + Returns: + dict: The presence state of the `target_user`, whose format depends + on the `as_event` argument. + """ if self.hs.is_mine(target_user): if check_auth: visible = yield self.is_presence_visible( @@ -234,6 +248,20 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def get_states(self, target_users, auth_user, as_event=False, check_auth=True): + """A batched version of the `get_state` method that accepts a list of + `target_users` + + Args: + target_users (list): The list of UserID's whose presence we want + auth_user (UserID): The user requesting the presence, used for + checking if said user is allowed to see the persence of the + `target_users` + as_event (bool): Format the return as an event or not? + check_auth (bool): Perform the auth checks or not? + + Returns: + dict: A mapping from user -> presence_state + """ local_users, remote_users = partitionbool( target_users, lambda u: self.hs.is_mine(u) From 83eb627b5a50ef7cd8803f6c6fae9c5f5271bcb1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 10:33:11 +0100 Subject: [PATCH 55/59] More helpful variable names --- synapse/handlers/presence.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1177cbe51..2b103b48b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -268,10 +268,10 @@ class PresenceHandler(BaseHandler): ) if check_auth: - for u in local_users: + for user in local_users: visible = yield self.is_presence_visible( observer_user=auth_user, - observed_user=u + observed_user=user ) if not visible: @@ -279,9 +279,9 @@ class PresenceHandler(BaseHandler): results = {} if local_users: - for u in local_users: - if u in self._user_cachemap: - results[u] = self._user_cachemap[u].get_state() + for user in local_users: + if user in self._user_cachemap: + results[user] = self._user_cachemap[user].get_state() local_to_user = {u.localpart: u for u in local_users} @@ -295,9 +295,9 @@ class PresenceHandler(BaseHandler): res["status_msg"] = state["status_msg"] results[local_to_user[local_part]] = res - for u in remote_users: + for user in remote_users: # TODO(paul): Have remote server send us permissions set - results[u] = self._get_or_offline_usercache(u).get_state() + results[user] = self._get_or_offline_usercache(user).get_state() for state in results.values(): if "last_active" in state: From 85d0bc3bdce038726be4daad6560e1a2a982c9fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 11:00:38 +0100 Subject: [PATCH 56/59] Reduce cache size from obscenely large to quite large --- synapse/storage/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ecb62e6df..ab3ad5a07 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -240,7 +240,7 @@ class StateStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) - @cached(num_args=2, lru=True, max_entries=100000) + @cached(num_args=2, lru=True, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", From 8e254862f49e5183b0a0e9c4a41ae1f1c3477418 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 11:11:33 +0100 Subject: [PATCH 57/59] Don't assume @cachedList function returns keys for everything --- synapse/util/caches/descriptors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 83bfec2f0..362944bc5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -293,7 +293,7 @@ class CacheListDescriptor(object): # we can insert the new deferred into the cache. for arg in missing: observer = ret_d.observe() - observer.addCallback(lambda r, arg: r[arg], arg) + observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) From e55291ce5ec0c65c54363ef9366c2357df2ee44f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 11:17:37 +0100 Subject: [PATCH 58/59] None check --- synapse/handlers/presence.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2b103b48b..748432959 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -290,6 +290,8 @@ class PresenceHandler(BaseHandler): ) for local_part, state in states.items(): + if stat is None: + continue res = {"presence": state["state"]} if "status_msg" in state and state["status_msg"]: res["status_msg"] = state["status_msg"] From 0d4abf77773ca0af73422aff1a35bb73c9235e1f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Aug 2015 11:19:08 +0100 Subject: [PATCH 59/59] Typo --- synapse/handlers/presence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 748432959..e91e81831 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -290,7 +290,7 @@ class PresenceHandler(BaseHandler): ) for local_part, state in states.items(): - if stat is None: + if state is None: continue res = {"presence": state["state"]} if "status_msg" in state and state["status_msg"]: