From caddadfc5ac61d1c91fbaf29bf3298f90a140560 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 15:04:57 +0000
Subject: [PATCH 01/45] Change device_inbox stream index to include user

This makes fetching the most recently changed users much quicker, and
brings it in line with e.g. presence_stream indices.
---
 synapse/storage/deviceinbox.py               | 38 ++++++++++++++++++-
 synapse/storage/prepare_database.py          |  2 +-
 .../storage/schema/delta/40/device_inbox.sql | 20 ++++++++++
 3 files changed, 57 insertions(+), 3 deletions(-)
 create mode 100644 synapse/storage/schema/delta/40/device_inbox.sql

diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 2821eb89c..b71ac3ae3 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -18,13 +18,29 @@ import ujson

 from twisted.internet import defer

-from ._base import SQLBaseStore
+from .background_updates import BackgroundUpdateStore

 logger = logging.getLogger(__name__)


-class DeviceInboxStore(SQLBaseStore):
+class DeviceInboxStore(BackgroundUpdateStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, hs):
+        super(DeviceInboxStore, self).__init__(hs)
+
+        self.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID,
+            self._background_drop_index_device_inbox,
+        )

     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@@ -368,3 +384,21 @@ class DeviceInboxStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
+
+    @defer.inlineCallbacks
+    def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            conn.set_session(autocommit=True)
+            try:
+                txn = conn.cursor()
+                txn.execute(
+                    "DROP INDEX IF EXISTS device_inbox_stream_id"
+                )
+            finally:
+                conn.set_session(autocommit=False)
+
+        yield self.runWithConnection(reindex_txn)
+
+        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        defer.returnValue(1)

diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e46ae6502..b357f22be 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 39
+SCHEMA_VERSION = 40

 dir_path = os.path.abspath(os.path.dirname(__file__))

diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql
new file mode 100644
index 000000000..ce58fe208
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_inbox.sql
@@ -0,0 +1,20 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_inbox_stream_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');

From 5a32e9273ec9759caf09d5b8204dd29e7a007b97 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 15:11:27 +0000
Subject: [PATCH 02/45] Don't disable autocommit

---
 synapse/storage/deviceinbox.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index b71ac3ae3..b0ab70baf 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -388,14 +388,10 @@ class DeviceInboxStore(BackgroundUpdateStore):
     @defer.inlineCallbacks
     def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
-            conn.set_session(autocommit=True)
-            try:
-                txn = conn.cursor()
-                txn.execute(
-                    "DROP INDEX IF EXISTS device_inbox_stream_id"
-                )
-            finally:
-                conn.set_session(autocommit=False)
+            txn = conn.cursor()
+            txn.execute(
+                "DROP INDEX IF EXISTS device_inbox_stream_id"
+            )

         yield self.runWithConnection(reindex_txn)

From ab655dca339f8d4168079cc2b4529dc50265fc83 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 15:15:25 +0000
Subject: [PATCH 03/45] Explicitly close the cursor

---
 synapse/storage/deviceinbox.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index b0ab70baf..bde3b5cbb 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -392,6 +392,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             txn.execute(
                 "DROP INDEX IF EXISTS device_inbox_stream_id"
             )
+            txn.close()

         yield self.runWithConnection(reindex_txn)
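Note: taken together, patches 01-03 converge on a simple shape for a one-shot
background update. The following is an illustrative sketch, not code from the
series; ExampleStore and the index name are hypothetical, and it assumes the
BackgroundUpdateStore helpers used above (register_background_update_handler,
runWithConnection, _end_background_update).

    # Hypothetical store showing the registration and handler pattern above.
    from twisted.internet import defer

    from .background_updates import BackgroundUpdateStore


    class ExampleStore(BackgroundUpdateStore):
        EXAMPLE_DROP = "example_index_drop"

        def __init__(self, hs):
            super(ExampleStore, self).__init__(hs)

            self.register_background_update_handler(
                self.EXAMPLE_DROP,
                self._background_drop_example_index,
            )

        @defer.inlineCallbacks
        def _background_drop_example_index(self, progress, batch_size):
            def drop_txn(conn):
                # Use a raw cursor and close it explicitly (patch 03),
                # without touching the connection's autocommit mode
                # (patch 02 reverted that).
                txn = conn.cursor()
                txn.execute("DROP INDEX IF EXISTS example_index")
                txn.close()

            yield self.runWithConnection(drop_txn)

            # Mark the update as finished; the return value is the number of
            # items processed, so a one-shot update just returns 1.
            yield self._end_background_update(self.EXAMPLE_DROP)
            defer.returnValue(1)

The SQL delta above schedules the work: the two background_updates rows mean
the new index is built first, and the old index is only dropped once the
'device_inbox_stream_index' update it depends on has completed.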
listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index b5e1d659e..4d081eccd 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -90,8 +90,7 @@ class ClientReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -110,9 +109,6 @@ class ClientReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -132,11 +128,7 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index c6810b83d..90a481675 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,8 +86,7 @@ class FederationReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -101,9 +100,6 @@ class FederationReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -123,11 +119,7 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 23aae8a09..ec06620ef 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -82,8 +82,7 @@ class FederationSenderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -93,9 +92,6 @@ class FederationSenderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: 
             reactor.listenTCP(
                 port,
@@ -115,11 +111,7 @@ class FederationSenderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
+                bind_addresses = listener["bind_addresses"]

                 for address in bind_addresses:
                     reactor.listenTCP(

diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 6c69ccd7e..e0b87468f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -107,8 +107,7 @@ def build_resource_for_web_client(hs):
 class SynapseHomeServer(HomeServer):
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         tls = listener_config.get("tls", False)
         site_tag = listener_config.get("tag", port)

@@ -175,9 +174,6 @@ class SynapseHomeServer(HomeServer):

         root_resource = create_resource_tree(resources, root_resource)

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         if tls:
             for address in bind_addresses:
                 reactor.listenSSL(
@@ -212,11 +208,7 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
+                bind_addresses = listener["bind_addresses"]

                 for address in bind_addresses:
                     reactor.listenTCP(

diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index a47283e52..ef17b158a 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -87,8 +87,7 @@ class MediaRepositoryServer(HomeServer):

     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -107,9 +106,6 @@ class MediaRepositoryServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -129,11 +125,7 @@ class MediaRepositoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
+                bind_addresses = listener["bind_addresses"]

                 for address in bind_addresses:
                     reactor.listenTCP(

diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 57e097fa1..073f2c248 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -121,8 +121,7 @@ class PusherServer(HomeServer):

     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -132,9 +131,6 @@ class PusherServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -154,11 +150,7 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
+                bind_addresses = listener["bind_addresses"]

                 for address in bind_addresses:
                     reactor.listenTCP(

diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 439daaa60..4dfc2dc64 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -289,8 +289,7 @@ class SynchrotronServer(HomeServer):

     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -312,9 +311,6 @@ class SynchrotronServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -334,11 +330,7 @@ class SynchrotronServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
+                bind_addresses = listener["bind_addresses"]

                 for address in bind_addresses:
                     reactor.listenTCP(

diff --git a/synapse/config/server.py b/synapse/config/server.py
index 5e6b2a68a..59687ee39 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -42,6 +42,15 @@ class ServerConfig(Config):

         self.listeners = config.get("listeners", [])

+        for listener in self.listeners:
+            bind_address = listener.get("bind_address", None)
+            bind_addresses = listener.setdefault("bind_addresses", [])
+
+            if bind_address:
+                bind_addresses.append(bind_address)
+            elif not bind_addresses:
+                bind_addresses.append('')
+
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))

         bind_port = config.get("bind_port")
@@ -54,7 +63,7 @@ class ServerConfig(Config):

             self.listeners.append({
                 "port": bind_port,
-                "bind_address": bind_host,
+                "bind_addresses": [bind_host],
                 "tls": True,
                 "type": "http",
                 "resources": [
@@ -73,7 +82,7 @@ class ServerConfig(Config):
         if unsecure_port:
             self.listeners.append({
                 "port": unsecure_port,
-                "bind_address": bind_host,
+                "bind_addresses": [bind_host],
                 "tls": False,
                 "type": "http",
                 "resources": [
@@ -92,7 +101,7 @@ class ServerConfig(Config):
         if manhole:
             self.listeners.append({
                 "port": manhole,
-                "bind_address": "127.0.0.1",
+                "bind_addresses": ["127.0.0.1"],
                 "type": "manhole",
             })

@@ -100,7 +109,7 @@ class ServerConfig(Config):
         if metrics_port:
             self.listeners.append({
                 "port": metrics_port,
-                "bind_address": config.get("metrics_bind_host", "127.0.0.1"),
+                "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
                 "tls": False,
                 "type": "http",
                 "resources": [
From b1dfd202928174ca5b377196b813e6ce51fe0999 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 17:23:18 +0000
Subject: [PATCH 05/45] Pop bind_address

---
 synapse/config/server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/config/server.py b/synapse/config/server.py
index 59687ee39..1f9999d57 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -43,7 +43,7 @@ class ServerConfig(Config):
         self.listeners = config.get("listeners", [])

         for listener in self.listeners:
-            bind_address = listener.get("bind_address", None)
+            bind_address = listener.pop("bind_address", None)
             bind_addresses = listener.setdefault("bind_addresses", [])

             if bind_address:
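Note: the net effect of patches 04 and 05 on listener configuration can be
seen in isolation. This standalone sketch is hypothetical (normalise_listener
is not a Synapse function); it mirrors the ServerConfig logic above, folding a
legacy bind_address key into the bind_addresses list that every worker now
reads unconditionally.

    # Hypothetical helper mirroring the ServerConfig normalisation above.
    def normalise_listener(listener):
        bind_address = listener.pop("bind_address", None)
        bind_addresses = listener.setdefault("bind_addresses", [])

        if bind_address:
            bind_addresses.append(bind_address)
        elif not bind_addresses:
            # The empty string makes reactor.listenTCP bind to all
            # interfaces, which restores the old default behaviour.
            bind_addresses.append('')

        return listener

    assert normalise_listener({"port": 8448, "bind_address": "::"}) == {
        "port": 8448, "bind_addresses": ["::"],
    }
    assert normalise_listener({"port": 8008}) == {
        "port": 8008, "bind_addresses": [''],
    }

Doing the normalisation once in ServerConfig, and popping the legacy key in
patch 05 so it cannot leak further, is what lets every app simply index
listener["bind_addresses"] instead of repeating the fallback logic.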
From 7e6c2937c327c76e32fb663d4a94072a0492c338 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 10 Jan 2017 18:16:54 +0000
Subject: [PATCH 06/45] Split out static auth methods from Auth object

---
 synapse/api/auth.py | 1168 +++++++++++++++++++++++--------------------
 1 file changed, 622 insertions(+), 546 deletions(-)

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f93e45a74..5e2b89c32 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -27,7 +27,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)
@@ -43,6 +42,622 @@ AuthEventTypes = (
 GUEST_DEVICE_ID = "guest_device"


+class Auther(object):
+    @staticmethod
+    def check(event, auth_events, do_sig_check=True):
+        """ Checks if this event is correctly authed.
+
+        Args:
+            event: the event being checked.
+            auth_events (dict: event-key -> event): the existing room state.
+
+
+        Returns:
+            True if the auth checks pass.
+        """
+        Auther.check_size_limits(event)
+
+        if not hasattr(event, "room_id"):
+            raise AuthError(500, "Event has no room_id: %s" % event)
+
+        if do_sig_check:
+            sender_domain = get_domain_from_id(event.sender)
+            event_id_domain = get_domain_from_id(event.event_id)
+
+            is_invite_via_3pid = (
+                event.type == EventTypes.Member
+                and event.membership == Membership.INVITE
+                and "third_party_invite" in event.content
+            )
+
+            # Check the sender's domain has signed the event
+            if not event.signatures.get(sender_domain):
+                # We allow invites via 3pid to have a sender from a different
+                # HS, as the sender must match the sender of the original
+                # 3pid invite. This is checked further down with the
+                # other dedicated membership checks.
+                if not is_invite_via_3pid:
+                    raise AuthError(403, "Event not signed by sender's server")
+
+            # Check the event_id's domain has signed the event
+            if not event.signatures.get(event_id_domain):
+                raise AuthError(403, "Event not signed by sending server")
+
+        if auth_events is None:
+            # Oh, we don't know what the state of the room was, so we
+            # are trusting that this is allowed (at least for now)
+            logger.warn("Trusting event: %s", event.event_id)
+            return True
+
+        if event.type == EventTypes.Create:
+            room_id_domain = get_domain_from_id(event.room_id)
+            if room_id_domain != sender_domain:
+                raise AuthError(
+                    403,
+                    "Creation event's room_id domain does not match sender's"
+                )
+            # FIXME
+            return True
+
+        creation_event = auth_events.get((EventTypes.Create, ""), None)
+
+        if not creation_event:
+            raise SynapseError(
+                403,
+                "Room %r does not exist" % (event.room_id,)
+            )
+
+        creating_domain = get_domain_from_id(event.room_id)
+        originating_domain = get_domain_from_id(event.sender)
+        if creating_domain != originating_domain:
+            if not Auther.can_federate(event, auth_events):
+                raise AuthError(
+                    403,
+                    "This room has been marked as unfederatable."
+                )
+
+        # FIXME: Temp hack
+        if event.type == EventTypes.Aliases:
+            if not event.is_state():
+                raise AuthError(
+                    403,
+                    "Alias event must be a state event",
+                )
+            if not event.state_key:
+                raise AuthError(
+                    403,
+                    "Alias event must have non-empty state_key"
+                )
+            sender_domain = get_domain_from_id(event.sender)
+            if event.state_key != sender_domain:
+                raise AuthError(
+                    403,
+                    "Alias event's state_key does not match sender's domain"
+                )
+            return True
+
+        logger.debug(
+            "Auth events: %s",
+            [a.event_id for a in auth_events.values()]
+        )
+
+        if event.type == EventTypes.Member:
+            allowed = Auther.is_membership_change_allowed(
+                event, auth_events
+            )
+            if allowed:
+                logger.debug("Allowing! %s", event)
+            else:
+                logger.debug("Denying! %s", event)
+            return allowed
+
+        Auther.check_event_sender_in_room(event, auth_events)
+
+        # Special case to allow m.room.third_party_invite events wherever
+        # a user is allowed to issue invites. Fixes
+        # https://github.com/vector-im/vector-web/issues/1208 hopefully
+        if event.type == EventTypes.ThirdPartyInvite:
+            user_level = Auther._get_user_power_level(event.user_id, auth_events)
+            invite_level = Auther._get_named_level(auth_events, "invite", 0)
+
+            if user_level < invite_level:
+                raise AuthError(
+                    403, (
+                        "You cannot issue a third party invite for %s." %
+                        (event.content.display_name,)
+                    )
+                )
+            else:
+                return True
+
+        Auther._can_send_event(event, auth_events)
+
+        if event.type == EventTypes.PowerLevels:
+            Auther._check_power_levels(event, auth_events)
+
+        if event.type == EventTypes.Redaction:
+            Auther.check_redaction(event, auth_events)
+
+        logger.debug("Allowing! %s", event)
%s", event) + + @staticmethod + def check_size_limits(event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + @staticmethod + def can_federate(event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + + @staticmethod + def is_membership_change_allowed(event, auth_events): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (EventTypes.Create, "", ) + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + + target_user_id = event.state_key + + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) + if creating_domain != target_domain: + if not Auther.can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # get info about the caller + key = (EventTypes.Member, event.user_id, ) + caller = auth_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE + + # get info about the target + key = (EventTypes.Member, target_user_id, ) + target = auth_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN + target_banned = target and target.membership == Membership.BAN + + key = (EventTypes.JoinRules, "", ) + join_rule_event = auth_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: + join_rule = JoinRules.INVITE + + user_level = Auther._get_user_power_level(event.user_id, auth_events) + target_level = Auther._get_user_power_level( + target_user_id, auth_events + ) + + # FIXME (erikj): What should we do here as the default? + ban_level = Auther._get_named_level(auth_events, "ban", 50) + + logger.debug( + "is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "caller_invited": caller_invited, + "target_banned": target_banned, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + + if Membership.INVITE == membership and "third_party_invite" in event.content: + if not Auther._verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + return True + + if Membership.JOIN != membership: + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + + if not caller_in_room: # caller isn't joined + raise AuthError( + 403, + "%s not in room %s." % (event.user_id, event.room_id,) + ) + + if Membership.INVITE == membership: + # TODO (erikj): We should probably handle this more intelligently + # PRIVATE join rules. 
+
+            # Invites are valid iff caller is in the room and target isn't.
+            if target_banned:
+                raise AuthError(
+                    403, "%s is banned from the room" % (target_user_id,)
+                )
+            elif target_in_room:  # the target is already in the room.
+                raise AuthError(403, "%s is already in the room." %
+                                target_user_id)
+            else:
+                invite_level = Auther._get_named_level(auth_events, "invite", 0)
+
+                if user_level < invite_level:
+                    raise AuthError(
+                        403, "You cannot invite user %s." % target_user_id
+                    )
+        elif Membership.JOIN == membership:
+            # Joins are valid iff caller == target and they were:
+            # invited: They are accepting the invitation
+            # joined: It's a NOOP
+            if event.user_id != target_user_id:
+                raise AuthError(403, "Cannot force another user to join.")
+            elif target_banned:
+                raise AuthError(403, "You are banned from this room")
+            elif join_rule == JoinRules.PUBLIC:
+                pass
+            elif join_rule == JoinRules.INVITE:
+                if not caller_in_room and not caller_invited:
+                    raise AuthError(403, "You are not invited to this room.")
+            else:
+                # TODO (erikj): may_join list
+                # TODO (erikj): private rooms
+                raise AuthError(403, "You are not allowed to join this room")
+        elif Membership.LEAVE == membership:
+            # TODO (erikj): Implement kicks.
+            if target_banned and user_level < ban_level:
+                raise AuthError(
+                    403, "You cannot unban user &s." % (target_user_id,)
+                )
+            elif target_user_id != event.user_id:
+                kick_level = Auther._get_named_level(auth_events, "kick", 50)
+
+                if user_level < kick_level or user_level <= target_level:
+                    raise AuthError(
+                        403, "You cannot kick user %s." % target_user_id
+                    )
+        elif Membership.BAN == membership:
+            if user_level < ban_level or user_level <= target_level:
+                raise AuthError(403, "You don't have permission to ban")
+        else:
+            raise AuthError(500, "Unknown membership %s" % membership)
+
+        return True
+
+    @staticmethod
+    def check_event_sender_in_room(event, auth_events):
+        key = (EventTypes.Member, event.user_id, )
+        member_event = auth_events.get(key)
+
+        return Auther._check_joined_room(
+            member_event,
+            event.user_id,
+            event.room_id
+        )
+
+    @staticmethod
+    def _check_joined_room(member, user_id, room_id):
+        if not member or member.membership != Membership.JOIN:
+            raise AuthError(403, "User %s not in room %s (%s)" % (
+                user_id, room_id, repr(member)
+            ))
+
+    @staticmethod
+    def _get_send_level(etype, state_key, auth_events):
+        key = (EventTypes.PowerLevels, "", )
+        send_level_event = auth_events.get(key)
+        send_level = None
+        if send_level_event:
+            send_level = send_level_event.content.get("events", {}).get(
+                etype
+            )
+            if send_level is None:
+                if state_key is not None:
+                    send_level = send_level_event.content.get(
+                        "state_default", 50
+                    )
+                else:
+                    send_level = send_level_event.content.get(
+                        "events_default", 0
+                    )
+
+        if send_level:
+            send_level = int(send_level)
+        else:
+            send_level = 0
+
+        return send_level
+
+    @staticmethod
+    def _can_send_event(event, auth_events):
+        send_level = Auther._get_send_level(
+            event.type, event.get("state_key", None), auth_events
+        )
+        user_level = Auther._get_user_power_level(event.user_id, auth_events)
+
+        if user_level < send_level:
+            raise AuthError(
+                403,
+                "You don't have permission to post that to the room. " +
+                "user_level (%d) < send_level (%d)" % (user_level, send_level)
+            )
" + + "user_level (%d) < send_level (%d)" % (user_level, send_level) + ) + + # Check state_key + if hasattr(event, "state_key"): + if event.state_key.startswith("@"): + if event.state_key != event.user_id: + raise AuthError( + 403, + "You are not allowed to set others state" + ) + + return True + + @staticmethod + def check_redaction(event, auth_events): + """Check whether the event sender is allowed to redact the target event. + + Returns: + True if the the sender is allowed to redact the target event if the + target event was created by them. + False if the sender is allowed to redact the target event with no + further checks. + + Raises: + AuthError if the event sender is definitely not allowed to redact + the target event. + """ + user_level = Auther._get_user_power_level(event.user_id, auth_events) + + redact_level = Auther._get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + @staticmethod + def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = Auther._get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + @staticmethod + def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + @staticmethod + def _get_user_power_level(user_id, auth_events): + power_level_event = Auther._get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level 
+
+            if level is None:
+                return 0
+            else:
+                return int(level)
+        else:
+            key = (EventTypes.Create, "", )
+            create_event = auth_events.get(key)
+            if (create_event is not None and
+                    create_event.content["creator"] == user_id):
+                return 100
+            else:
+                return 0
+
+    @staticmethod
+    def _get_named_level(auth_events, name, default):
+        power_level_event = Auther._get_power_level_event(auth_events)
+
+        if not power_level_event:
+            return default
+
+        level = power_level_event.content.get(name, None)
+        if level is not None:
+            return int(level)
+        else:
+            return default
+
+    @staticmethod
+    def _verify_third_party_invite(event, auth_events):
+        """
+        Validates that the invite event is authorized by a previous third-party invite.
+
+        Checks that the public key, and keyserver, match those in the third party invite,
+        and that the invite event has a signature issued using that public key.
+
+        Args:
+            event: The m.room.member join event being validated.
+            auth_events: All relevant previous context events which may be used
+                for authorization decisions.
+
+        Return:
+            True if the event fulfills the expectations of a previous third party
+            invite event.
+        """
+        if "third_party_invite" not in event.content:
+            return False
+        if "signed" not in event.content["third_party_invite"]:
+            return False
+        signed = event.content["third_party_invite"]["signed"]
+        for key in {"mxid", "token"}:
+            if key not in signed:
+                return False
+
+        token = signed["token"]
+
+        invite_event = auth_events.get(
+            (EventTypes.ThirdPartyInvite, token,)
+        )
+        if not invite_event:
+            return False
+
+        if invite_event.sender != event.sender:
+            return False
+
+        if event.user_id != invite_event.user_id:
+            return False
+
+        if signed["mxid"] != event.state_key:
+            return False
+        if signed["token"] != token:
+            return False
+
+        for public_key_object in Auther.get_public_keys(invite_event):
+            public_key = public_key_object["public_key"]
+            try:
+                for server, signature_block in signed["signatures"].items():
+                    for key_name, encoded_signature in signature_block.items():
+                        if not key_name.startswith("ed25519:"):
+                            continue
+                        verify_key = decode_verify_key_bytes(
+                            key_name,
+                            decode_base64(public_key)
+                        )
+                        verify_signed_json(signed, server, verify_key)
+
+                        # We got the public key from the invite, so we know that the
+                        # correct server signed the signed bundle.
+                        # The caller is responsible for checking that the signing
+                        # server has not revoked that public key.
+                        return True
+            except (KeyError, SignatureVerifyException,):
+                continue
+        return False
+
+    @staticmethod
+    def get_public_keys(invite_event):
+        public_keys = []
+        if "public_key" in invite_event.content:
+            o = {
+                "public_key": invite_event.content["public_key"],
+            }
+            if "key_validity_url" in invite_event.content:
+                o["key_validity_url"] = invite_event.content["key_validity_url"]
+            public_keys.append(o)
+        public_keys.extend(invite_event.content.get("public_keys", []))
+        return public_keys
+
+
 class Auth(object):
     """
     FIXME: This class contains a mix of functions for authenticating users
@@ -78,130 +693,7 @@ class Auth(object):
             True if the auth checks pass.
""" with Measure(self.clock, "auth.check"): - self.check_size_limits(event) - - if not hasattr(event, "room_id"): - raise AuthError(500, "Event has no room_id: %s" % event) - - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - event_id_domain = get_domain_from_id(event.event_id) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True - - creation_event = auth_events.get((EventTypes.Create, ""), None) - - if not creation_event: - raise SynapseError( - 403, - "Room %r does not exist" % (event.room_id,) - ) - - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True - - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) - - if event.type == EventTypes.Member: - allowed = self.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed - - self.check_event_sender_in_room(event, auth_events) - - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = self._get_user_power_level(event.user_id, auth_events) - invite_level = self._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) - ) - else: - return True - - self._can_send_event(event, auth_events) - - if event.type == EventTypes.PowerLevels: - self._check_power_levels(event, auth_events) - - if event.type == EventTypes.Redaction: - self.check_redaction(event, auth_events) - - logger.debug("Allowing! 
%s", event) + Auther.check(event, auth_events, do_sig_check=do_sig_check) def check_size_limits(self, event): def too_big(field): @@ -300,16 +792,6 @@ class Auth(object): ) defer.returnValue(ret) - def check_event_sender_in_room(self, event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return self._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: raise AuthError(403, "User %s not in room %s (%s)" % ( @@ -321,267 +803,8 @@ class Auth(object): return creation_event.content.get("m.federate", True) is True - @log_function - def is_membership_change_allowed(self, event, auth_events): - membership = event.content["membership"] - - # Check if this is the room creator joining: - if len(event.prev_events) == 1 and Membership.JOIN == membership: - # Get room creation event: - key = (EventTypes.Create, "", ) - create = auth_events.get(key) - if create and event.prev_events[0][0] == create.event_id: - if create.content["creator"] == event.state_key: - return True - - target_user_id = event.state_key - - creating_domain = get_domain_from_id(event.room_id) - target_domain = get_domain_from_id(target_user_id) - if creating_domain != target_domain: - if not self.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # get info about the caller - key = (EventTypes.Member, event.user_id, ) - caller = auth_events.get(key) - - caller_in_room = caller and caller.membership == Membership.JOIN - caller_invited = caller and caller.membership == Membership.INVITE - - # get info about the target - key = (EventTypes.Member, target_user_id, ) - target = auth_events.get(key) - - target_in_room = target and target.membership == Membership.JOIN - target_banned = target and target.membership == Membership.BAN - - key = (EventTypes.JoinRules, "", ) - join_rule_event = auth_events.get(key) - if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) - else: - join_rule = JoinRules.INVITE - - user_level = self._get_user_power_level(event.user_id, auth_events) - target_level = self._get_user_power_level( - target_user_id, auth_events - ) - - # FIXME (erikj): What should we do here as the default? - ban_level = self._get_named_level(auth_events, "ban", 50) - - logger.debug( - "is_membership_change_allowed: %s", - { - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": event.user_id, - } - ) - - if Membership.INVITE == membership and "third_party_invite" in event.content: - if not self._verify_third_party_invite(event, auth_events): - raise AuthError(403, "You are not invited to this room.") - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - return True - - if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): - return True - - if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) - - if Membership.INVITE == membership: - # TODO (erikj): We should probably handle this more intelligently - # PRIVATE join rules. 
-
-            # Invites are valid iff caller is in the room and target isn't.
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
-            elif target_in_room:  # the target is already in the room.
-                raise AuthError(403, "%s is already in the room." %
-                                target_user_id)
-            else:
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, "You cannot invite user %s." % target_user_id
-                    )
-        elif Membership.JOIN == membership:
-            # Joins are valid iff caller == target and they were:
-            # invited: They are accepting the invitation
-            # joined: It's a NOOP
-            if event.user_id != target_user_id:
-                raise AuthError(403, "Cannot force another user to join.")
-            elif target_banned:
-                raise AuthError(403, "You are banned from this room")
-            elif join_rule == JoinRules.PUBLIC:
-                pass
-            elif join_rule == JoinRules.INVITE:
-                if not caller_in_room and not caller_invited:
-                    raise AuthError(403, "You are not invited to this room.")
-            else:
-                # TODO (erikj): may_join list
-                # TODO (erikj): private rooms
-                raise AuthError(403, "You are not allowed to join this room")
-        elif Membership.LEAVE == membership:
-            # TODO (erikj): Implement kicks.
-            if target_banned and user_level < ban_level:
-                raise AuthError(
-                    403, "You cannot unban user &s." % (target_user_id,)
-                )
-            elif target_user_id != event.user_id:
-                kick_level = self._get_named_level(auth_events, "kick", 50)
-
-                if user_level < kick_level or user_level <= target_level:
-                    raise AuthError(
-                        403, "You cannot kick user %s." % target_user_id
-                    )
-        elif Membership.BAN == membership:
-            if user_level < ban_level or user_level <= target_level:
-                raise AuthError(403, "You don't have permission to ban")
-        else:
-            raise AuthError(500, "Unknown membership %s" % membership)
-
-        return True
-
-    def _verify_third_party_invite(self, event, auth_events):
-        """
-        Validates that the invite event is authorized by a previous third-party invite.
-
-        Checks that the public key, and keyserver, match those in the third party invite,
-        and that the invite event has a signature issued using that public key.
-
-        Args:
-            event: The m.room.member join event being validated.
-            auth_events: All relevant previous context events which may be used
-                for authorization decisions.
-
-        Return:
-            True if the event fulfills the expectations of a previous third party
-            invite event.
- """ - if "third_party_invite" not in event.content: - return False - if "signed" not in event.content["third_party_invite"]: - return False - signed = event.content["third_party_invite"]["signed"] - for key in {"mxid", "token"}: - if key not in signed: - return False - - token = signed["token"] - - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) - if not invite_event: - return False - - if invite_event.sender != event.sender: - return False - - if event.user_id != invite_event.user_id: - return False - - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - - for public_key_object in self.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True - except (KeyError, SignatureVerifyException,): - continue - return False - def get_public_keys(self, invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys - - def _get_power_level_event(self, auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) - - def _get_user_power_level(self, user_id, auth_events): - power_level_event = self._get_power_level_event(auth_events) - - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level = power_level_event.content.get("users_default", 0) - - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 - - def _get_named_level(self, auth_events, name, default): - power_level_event = self._get_power_level_event(auth_events) - - if not power_level_event: - return default - - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default + return Auther.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -975,54 +1198,7 @@ class Auth(object): defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @log_function - def _can_send_event(self, event, 
-    def _can_send_event(self, event, auth_events):
-        send_level = self._get_send_level(
-            event.type, event.get("state_key", None), auth_events
-        )
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        if user_level < send_level:
-            raise AuthError(
-                403,
-                "You don't have permission to post that to the room. " +
-                "user_level (%d) < send_level (%d)" % (user_level, send_level)
-            )
-
-        # Check state_key
-        if hasattr(event, "state_key"):
-            if event.state_key.startswith("@"):
-                if event.state_key != event.user_id:
-                    raise AuthError(
-                        403,
-                        "You are not allowed to set others state"
-                    )
-
-        return True
+        return Auther._get_send_level(etype, state_key, auth_events)

     def check_redaction(self, event, auth_events):
         """Check whether the event sender is allowed to redact the target event.
@@ -1037,107 +1213,7 @@ class Auth(object):
            AuthError if the event sender is definitely not allowed to redact
            the target event.
        """
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        redact_level = self._get_named_level(auth_events, "redact", 50)
-
-        if user_level >= redact_level:
-            return False
-
-        redacter_domain = get_domain_from_id(event.event_id)
-        redactee_domain = get_domain_from_id(event.redacts)
-        if redacter_domain == redactee_domain:
-            return True
-
-        raise AuthError(
-            403,
-            "You don't have permission to redact events"
-        )
-
-    def _check_power_levels(self, event, auth_events):
-        user_list = event.content.get("users", {})
-        # Validate users
-        for k, v in user_list.items():
-            try:
-                UserID.from_string(k)
-            except:
-                raise SynapseError(400, "Not a valid user_id: %s" % (k,))
-
-            try:
-                int(v)
-            except:
-                raise SynapseError(400, "Not a valid power level: %s" % (v,))
-
-        key = (event.type, event.state_key, )
-        current_state = auth_events.get(key)
-
-        if not current_state:
-            return
-
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        # Check other levels:
-        levels_to_check = [
-            ("users_default", None),
-            ("events_default", None),
-            ("state_default", None),
-            ("ban", None),
-            ("redact", None),
-            ("kick", None),
-            ("invite", None),
-        ]
-
-        old_list = current_state.content.get("users")
-        for user in set(old_list.keys() + user_list.keys()):
-            levels_to_check.append(
-                (user, "users")
-            )
-
-        old_list = current_state.content.get("events")
-        new_list = event.content.get("events")
-        for ev_id in set(old_list.keys() + new_list.keys()):
-            levels_to_check.append(
-                (ev_id, "events")
-            )
-
-        old_state = current_state.content
-        new_state = event.content
-
-        for level_to_check, dir in levels_to_check:
-            old_loc = old_state
-            new_loc = new_state
-            if dir:
-                old_loc = old_loc.get(dir, {})
-                new_loc = new_loc.get(dir, {})
-
-            if level_to_check in old_loc:
-                old_level = int(old_loc[level_to_check])
-            else:
-                old_level = None
-
-            if level_to_check in new_loc:
-                new_level = int(new_loc[level_to_check])
-            else:
-                new_level = None
-
-            if new_level is not None and old_level is not None:
-                if new_level == old_level:
-                    continue
-
-            if dir == "users" and level_to_check != event.user_id:
-                if old_level == user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to remove ops level equal "
-                        "to your own"
-                    )
-
-            if old_level > user_level or new_level > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater "
-                    "than your own"
-                )
+        return Auther.check_redaction(event, auth_events)

     @defer.inlineCallbacks
     def check_can_change_room_list(self, room_id, user):
@@ -1167,10 +1243,10 @@ class Auth(object):

         if power_level_event:
"")] = power_level_event - send_level = self._get_send_level( + send_level = Auther._get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = self._get_user_power_level(user_id, auth_events) + user_level = Auther._get_user_power_level(user_id, auth_events) if user_level < send_level: raise AuthError( From bf5c9706d9053ffad05fc12eca71b8d441fa9306 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Jan 2017 10:32:52 +0000 Subject: [PATCH 07/45] Remove full_twisted_stacktraces option The debug 'full_twisted_stacktraces' flag caused synapse to rewrite twisted deferreds to always fire the callback on the next reactor tick. This was to force the deferred to always store the stacktraces on exceptions, and thus be more likely to have a full stacktrace when it reaches the final error handlers and gets printed to the logs. Dynamically rewriting things is generally bad, and in particular this change violates assumptions of various bits of Twisted. This wouldn't necessarily be so bad, but it turns out this option has been turned on on some production servers. Turning the option can cause e.g. #1778. For now, lets just entirely nuke this option. --- synapse/config/logger.py | 8 ----- synapse/util/debug.py | 71 ---------------------------------------- 2 files changed, 79 deletions(-) delete mode 100644 synapse/util/debug.py diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 63e69a7e0..77ded0ad2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -22,7 +22,6 @@ import yaml from string import Template import os import signal -from synapse.util.debug import debug_deferreds DEFAULT_LOG_CONFIG = Template(""" @@ -71,8 +70,6 @@ class LoggingConfig(Config): self.verbosity = config.get("verbose", 0) self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - if config.get("full_twisted_stacktraces"): - debug_deferreds() def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") @@ -88,11 +85,6 @@ class LoggingConfig(Config): # A yaml python logging config file log_config: "%(log_config)s" - - # Stop twisted from discarding the stack traces of exceptions in - # deferreds by waiting a reactor tick before running a deferred's - # callbacks. - # full_twisted_stacktraces: true """ % locals() def read_arguments(self, args): diff --git a/synapse/util/debug.py b/synapse/util/debug.py deleted file mode 100644 index dc49162e6..000000000 --- a/synapse/util/debug.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer, reactor -from functools import wraps -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext - - -def debug_deferreds(): - """Cause all deferreds to wait for a reactor tick before running their - callbacks. 
From bf5c9706d9053ffad05fc12eca71b8d441fa9306 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Thu, 12 Jan 2017 10:32:52 +0000
Subject: [PATCH 07/45] Remove full_twisted_stacktraces option

The debug 'full_twisted_stacktraces' flag caused synapse to rewrite
twisted deferreds to always fire the callback on the next reactor tick.
This was to force the deferred to always store the stacktraces on
exceptions, and thus be more likely to have a full stacktrace when it
reaches the final error handlers and gets printed to the logs.

Dynamically rewriting things is generally bad, and in particular this
change violates assumptions of various bits of Twisted. This wouldn't
necessarily be so bad, but it turns out this option has been turned on
on some production servers. Turning the option on can cause e.g. #1778.

For now, let's just entirely nuke this option.
---
 synapse/config/logger.py |  8 -----
 synapse/util/debug.py    | 71 ----------------------------------------
 2 files changed, 79 deletions(-)
 delete mode 100644 synapse/util/debug.py

diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 63e69a7e0..77ded0ad2 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -22,7 +22,6 @@ import yaml
 from string import Template
 import os
 import signal
-from synapse.util.debug import debug_deferreds


 DEFAULT_LOG_CONFIG = Template("""
@@ -71,8 +70,6 @@ class LoggingConfig(Config):
         self.verbosity = config.get("verbose", 0)
         self.log_config = self.abspath(config.get("log_config"))
         self.log_file = self.abspath(config.get("log_file"))
-        if config.get("full_twisted_stacktraces"):
-            debug_deferreds()

     def default_config(self, config_dir_path, server_name, **kwargs):
         log_file = self.abspath("homeserver.log")
@@ -88,11 +85,6 @@ class LoggingConfig(Config):

         # A yaml python logging config file
         log_config: "%(log_config)s"
-
-        # Stop twisted from discarding the stack traces of exceptions in
-        # deferreds by waiting a reactor tick before running a deferred's
-        # callbacks.
-        # full_twisted_stacktraces: true
         """ % locals()

     def read_arguments(self, args):
diff --git a/synapse/util/debug.py b/synapse/util/debug.py
deleted file mode 100644
index dc49162e6..000000000
--- a/synapse/util/debug.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer, reactor
-from functools import wraps
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-
-
-def debug_deferreds():
-    """Cause all deferreds to wait for a reactor tick before running their
-    callbacks. This increases the chance of getting a stack trace out of
-    a defer.inlineCallback since the code waiting on the deferred will get
-    a chance to add an errback before the deferred runs."""
-
-    # Helper method for retrieving and restoring the current logging context
-    # around a callback.
-    def with_logging_context(fn):
-        context = LoggingContext.current_context()
-
-        def restore_context_callback(x):
-            with PreserveLoggingContext(context):
-                return fn(x)
-
-        return restore_context_callback
-
-    # We are going to modify the __init__ method of defer.Deferred so we
-    # need to get a copy of the old method so we can still call it.
-    old__init__ = defer.Deferred.__init__
-
-    # We need to create a deferred to bounce the callbacks through the reactor
-    # but we don't want to add a callback when we create that deferred so we
-    # we create a new type of deferred that uses the old __init__ method.
-    # This is safe as long as the old __init__ method doesn't invoke an
-    # __init__ using super.
-    class Bouncer(defer.Deferred):
-        __init__ = old__init__
-
-    # We'll add this as a callback to all Deferreds. Twisted will wait until
-    # the bouncer deferred resolves before calling the callbacks of the
-    # original deferred.
-    def bounce_callback(x):
-        bouncer = Bouncer()
-        reactor.callLater(0, with_logging_context(bouncer.callback), x)
-        return bouncer
-
-    # We'll add this as an errback to all Deferreds. Twisted will wait until
-    # the bouncer deferred resolves before calling the errbacks of the
-    # original deferred.
-    def bounce_errback(x):
-        bouncer = Bouncer()
-        reactor.callLater(0, with_logging_context(bouncer.errback), x)
-        return bouncer
-
-    @wraps(old__init__)
-    def new__init__(self, *args, **kargs):
-        old__init__(self, *args, **kargs)
-        self.addCallbacks(bounce_callback, bounce_errback)
-
-    defer.Deferred.__init__ = new__init__

From ebf94aff8d8cf6a6ed187b2c8e6aaa69f3912a48 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Thu, 12 Jan 2017 17:19:47 +0000
Subject: [PATCH 08/45] Fix spurious Unhandled Error log lines

---
 synapse/rest/client/transactions.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 351170edb..efa77b8c5 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -86,7 +86,11 @@ class HttpTransactionCache(object):
             pass  # execute the function instead.

         deferred = fn(*args, **kwargs)
-        observable = ObservableDeferred(deferred)
+
+        # We don't add an errback to the raw deferred, so we ask ObservableDeferred
+        # to swallow the error. This is fine as the error will still be reported
+        # to the observers.
+        observable = ObservableDeferred(deferred, consumeErrors=True)
         self.transactions[txn_key] = (observable, self.clock.time_msec())
         return observable.observe()
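Note: the pattern patch 08 relies on generalises. When a Deferred's result is
shared out through an ObservableDeferred and nothing errbacks the original,
consumeErrors=True keeps Twisted from logging "Unhandled Error" when the
deferred is garbage collected. A minimal sketch, assuming ObservableDeferred
is importable from synapse.util.async as in this era of the codebase:

    # Minimal sketch, not Synapse code: share one computation between
    # callers without tripping Twisted's "Unhandled Error" logging.
    from synapse.util.async import ObservableDeferred


    def shared_call(cache, key, fn, *args, **kwargs):
        if key not in cache:
            # consumeErrors=True: failures are delivered to every observer
            # returned below, instead of being logged as unhandled when the
            # original deferred is garbage collected.
            cache[key] = ObservableDeferred(
                fn(*args, **kwargs), consumeErrors=True,
            )
        return cache[key].observe()

Each caller gets its own deferred from observe(), so every caller still sees
the failure; only the spurious log line goes away.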
+        observable = ObservableDeferred(deferred, consumeErrors=True)
 
         self.transactions[txn_key] = (observable, self.clock.time_msec())
         return observable.observe()

From 6f5e41e420d3a928c59841640610e1ad7756121c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Fri, 13 Jan 2017 12:52:11 +0000
Subject: [PATCH 09/45] README.rst: fix formatting

Fix formatting blooper introduced in
https://github.com/matrix-org/synapse/pull/1672 :/
---
 README.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.rst b/README.rst
index ba21c52ae..77e0b470a 100644
--- a/README.rst
+++ b/README.rst
@@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
            python-devel libffi-devel libopenssl-devel libjpeg62-devel
 
 Installing prerequisites on OpenBSD::
+
     doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
                  libxslt

From 8b2fa382568373573d3b1d520e8ebc2ef39e2935 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 13 Jan 2017 15:07:32 +0000
Subject: [PATCH 10/45] Split event auth code into separate module
---
 synapse/api/auth.py   | 654 +-----------------------------------------
 synapse/event_auth.py | 641 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 650 insertions(+), 645 deletions(-)
 create mode 100644 synapse/event_auth.py

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5e2b89c32..b781d41a6 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -16,16 +16,13 @@
 import logging
 
 import pymacaroons
-from canonicaljson import encode_canonical_json
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json, SignatureVerifyException
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64
 
 import synapse.types
+from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
+from synapse.api.errors import AuthError, Codes
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID
 from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.metrics import Measure
@@ -42,622 +39,6 @@ AuthEventTypes = (
 GUEST_DEVICE_ID = "guest_device"
 
-class Auther(object):
-    @staticmethod
-    def check(event, auth_events, do_sig_check=True):
-        """ Checks if this event is correctly authed.
-
-        Args:
-            event: the event being checked.
-            auth_events (dict: event-key -> event): the existing room state.
-
-
-        Returns:
-            True if the auth checks pass.
-        """
-        Auther.check_size_limits(event)
-
-        if not hasattr(event, "room_id"):
-            raise AuthError(500, "Event has no room_id: %s" % event)
-
-        if do_sig_check:
-            sender_domain = get_domain_from_id(event.sender)
-            event_id_domain = get_domain_from_id(event.event_id)
-
-            is_invite_via_3pid = (
-                event.type == EventTypes.Member
-                and event.membership == Membership.INVITE
-                and "third_party_invite" in event.content
-            )
-
-            # Check the sender's domain has signed the event
-            if not event.signatures.get(sender_domain):
-                # We allow invites via 3pid to have a sender from a different
-                # HS, as the sender must match the sender of the original
-                # 3pid invite. This is checked further down with the
-                # other dedicated membership checks.
- if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - # Check the event_id's domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return True - - if event.type == EventTypes.Create: - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" - ) - # FIXME - return True - - creation_event = auth_events.get((EventTypes.Create, ""), None) - - if not creation_event: - raise SynapseError( - 403, - "Room %r does not exist" % (event.room_id,) - ) - - creating_domain = get_domain_from_id(event.room_id) - originating_domain = get_domain_from_id(event.sender) - if creating_domain != originating_domain: - if not Auther.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # FIXME: Temp hack - if event.type == EventTypes.Aliases: - if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) - if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) - sender_domain = get_domain_from_id(event.sender) - if event.state_key != sender_domain: - raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" - ) - return True - - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) - - if event.type == EventTypes.Member: - allowed = Auther.is_membership_change_allowed( - event, auth_events - ) - if allowed: - logger.debug("Allowing! %s", event) - else: - logger.debug("Denying! %s", event) - return allowed - - Auther.check_event_sender_in_room(event, auth_events) - - # Special case to allow m.room.third_party_invite events wherever - # a user is allowed to issue invites. Fixes - # https://github.com/vector-im/vector-web/issues/1208 hopefully - if event.type == EventTypes.ThirdPartyInvite: - user_level = Auther._get_user_power_level(event.user_id, auth_events) - invite_level = Auther._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, ( - "You cannot issue a third party invite for %s." % - (event.content.display_name,) - ) - ) - else: - return True - - Auther._can_send_event(event, auth_events) - - if event.type == EventTypes.PowerLevels: - Auther._check_power_levels(event, auth_events) - - if event.type == EventTypes.Redaction: - Auther.check_redaction(event, auth_events) - - logger.debug("Allowing! 
%s", event) - - @staticmethod - def check_size_limits(event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") - - @staticmethod - def can_federate(event, auth_events): - creation_event = auth_events.get((EventTypes.Create, "")) - - return creation_event.content.get("m.federate", True) is True - - @staticmethod - def is_membership_change_allowed(event, auth_events): - membership = event.content["membership"] - - # Check if this is the room creator joining: - if len(event.prev_events) == 1 and Membership.JOIN == membership: - # Get room creation event: - key = (EventTypes.Create, "", ) - create = auth_events.get(key) - if create and event.prev_events[0][0] == create.event_id: - if create.content["creator"] == event.state_key: - return True - - target_user_id = event.state_key - - creating_domain = get_domain_from_id(event.room_id) - target_domain = get_domain_from_id(target_user_id) - if creating_domain != target_domain: - if not Auther.can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) - - # get info about the caller - key = (EventTypes.Member, event.user_id, ) - caller = auth_events.get(key) - - caller_in_room = caller and caller.membership == Membership.JOIN - caller_invited = caller and caller.membership == Membership.INVITE - - # get info about the target - key = (EventTypes.Member, target_user_id, ) - target = auth_events.get(key) - - target_in_room = target and target.membership == Membership.JOIN - target_banned = target and target.membership == Membership.BAN - - key = (EventTypes.JoinRules, "", ) - join_rule_event = auth_events.get(key) - if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) - else: - join_rule = JoinRules.INVITE - - user_level = Auther._get_user_power_level(event.user_id, auth_events) - target_level = Auther._get_user_power_level( - target_user_id, auth_events - ) - - # FIXME (erikj): What should we do here as the default? - ban_level = Auther._get_named_level(auth_events, "ban", 50) - - logger.debug( - "is_membership_change_allowed: %s", - { - "caller_in_room": caller_in_room, - "caller_invited": caller_invited, - "target_banned": target_banned, - "target_in_room": target_in_room, - "membership": membership, - "join_rule": join_rule, - "target_user_id": target_user_id, - "event.user_id": event.user_id, - } - ) - - if Membership.INVITE == membership and "third_party_invite" in event.content: - if not Auther._verify_third_party_invite(event, auth_events): - raise AuthError(403, "You are not invited to this room.") - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - return True - - if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): - return True - - if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) - - if Membership.INVITE == membership: - # TODO (erikj): We should probably handle this more intelligently - # PRIVATE join rules. 
- - # Invites are valid iff caller is in the room and target isn't. - if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) - elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) - else: - invite_level = Auther._get_named_level(auth_events, "invite", 0) - - if user_level < invite_level: - raise AuthError( - 403, "You cannot invite user %s." % target_user_id - ) - elif Membership.JOIN == membership: - # Joins are valid iff caller == target and they were: - # invited: They are accepting the invitation - # joined: It's a NOOP - if event.user_id != target_user_id: - raise AuthError(403, "Cannot force another user to join.") - elif target_banned: - raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC: - pass - elif join_rule == JoinRules.INVITE: - if not caller_in_room and not caller_invited: - raise AuthError(403, "You are not invited to this room.") - else: - # TODO (erikj): may_join list - # TODO (erikj): private rooms - raise AuthError(403, "You are not allowed to join this room") - elif Membership.LEAVE == membership: - # TODO (erikj): Implement kicks. - if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user &s." % (target_user_id,) - ) - elif target_user_id != event.user_id: - kick_level = Auther._get_named_level(auth_events, "kick", 50) - - if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) - elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: - raise AuthError(403, "You don't have permission to ban") - else: - raise AuthError(500, "Unknown membership %s" % membership) - - return True - - @staticmethod - def check_event_sender_in_room(event, auth_events): - key = (EventTypes.Member, event.user_id, ) - member_event = auth_events.get(key) - - return Auther._check_joined_room( - member_event, - event.user_id, - event.room_id - ) - - @staticmethod - def _check_joined_room(member, user_id, room_id): - if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) - - @staticmethod - def _get_send_level(etype, state_key, auth_events): - key = (EventTypes.PowerLevels, "", ) - send_level_event = auth_events.get(key) - send_level = None - if send_level_event: - send_level = send_level_event.content.get("events", {}).get( - etype - ) - if send_level is None: - if state_key is not None: - send_level = send_level_event.content.get( - "state_default", 50 - ) - else: - send_level = send_level_event.content.get( - "events_default", 0 - ) - - if send_level: - send_level = int(send_level) - else: - send_level = 0 - - return send_level - - @staticmethod - def _can_send_event(event, auth_events): - send_level = Auther._get_send_level( - event.type, event.get("state_key", None), auth_events - ) - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - if user_level < send_level: - raise AuthError( - 403, - "You don't have permission to post that to the room. 
" + - "user_level (%d) < send_level (%d)" % (user_level, send_level) - ) - - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) - - return True - - @staticmethod - def check_redaction(event, auth_events): - """Check whether the event sender is allowed to redact the target event. - - Returns: - True if the the sender is allowed to redact the target event if the - target event was created by them. - False if the sender is allowed to redact the target event with no - further checks. - - Raises: - AuthError if the event sender is definitely not allowed to redact - the target event. - """ - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - redact_level = Auther._get_named_level(auth_events, "redact", 50) - - if user_level >= redact_level: - return False - - redacter_domain = get_domain_from_id(event.event_id) - redactee_domain = get_domain_from_id(event.redacts) - if redacter_domain == redactee_domain: - return True - - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - @staticmethod - def _check_power_levels(event, auth_events): - user_list = event.content.get("users", {}) - # Validate users - for k, v in user_list.items(): - try: - UserID.from_string(k) - except: - raise SynapseError(400, "Not a valid user_id: %s" % (k,)) - - try: - int(v) - except: - raise SynapseError(400, "Not a valid power level: %s" % (v,)) - - key = (event.type, event.state_key, ) - current_state = auth_events.get(key) - - if not current_state: - return - - user_level = Auther._get_user_power_level(event.user_id, auth_events) - - # Check other levels: - levels_to_check = [ - ("users_default", None), - ("events_default", None), - ("state_default", None), - ("ban", None), - ("redact", None), - ("kick", None), - ("invite", None), - ] - - old_list = current_state.content.get("users") - for user in set(old_list.keys() + user_list.keys()): - levels_to_check.append( - (user, "users") - ) - - old_list = current_state.content.get("events") - new_list = event.content.get("events") - for ev_id in set(old_list.keys() + new_list.keys()): - levels_to_check.append( - (ev_id, "events") - ) - - old_state = current_state.content - new_state = event.content - - for level_to_check, dir in levels_to_check: - old_loc = old_state - new_loc = new_state - if dir: - old_loc = old_loc.get(dir, {}) - new_loc = new_loc.get(dir, {}) - - if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) - else: - old_level = None - - if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) - else: - new_level = None - - if new_level is not None and old_level is not None: - if new_level == old_level: - continue - - if dir == "users" and level_to_check != event.user_id: - if old_level == user_level: - raise AuthError( - 403, - "You don't have permission to remove ops level equal " - "to your own" - ) - - if old_level > user_level or new_level > user_level: - raise AuthError( - 403, - "You don't have permission to add ops level greater " - "than your own" - ) - - @staticmethod - def _get_power_level_event(auth_events): - key = (EventTypes.PowerLevels, "", ) - return auth_events.get(key) - - @staticmethod - def _get_user_power_level(user_id, auth_events): - power_level_event = Auther._get_power_level_event(auth_events) - - if power_level_event: - level = power_level_event.content.get("users", {}).get(user_id) - if not level: - level 
= power_level_event.content.get("users_default", 0) - - if level is None: - return 0 - else: - return int(level) - else: - key = (EventTypes.Create, "", ) - create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): - return 100 - else: - return 0 - - @staticmethod - def _get_named_level(auth_events, name, default): - power_level_event = Auther._get_power_level_event(auth_events) - - if not power_level_event: - return default - - level = power_level_event.content.get(name, None) - if level is not None: - return int(level) - else: - return default - - @staticmethod - def _verify_third_party_invite(event, auth_events): - """ - Validates that the invite event is authorized by a previous third-party invite. - - Checks that the public key, and keyserver, match those in the third party invite, - and that the invite event has a signature issued using that public key. - - Args: - event: The m.room.member join event being validated. - auth_events: All relevant previous context events which may be used - for authorization decisions. - - Return: - True if the event fulfills the expectations of a previous third party - invite event. - """ - if "third_party_invite" not in event.content: - return False - if "signed" not in event.content["third_party_invite"]: - return False - signed = event.content["third_party_invite"]["signed"] - for key in {"mxid", "token"}: - if key not in signed: - return False - - token = signed["token"] - - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) - if not invite_event: - return False - - if invite_event.sender != event.sender: - return False - - if event.user_id != invite_event.user_id: - return False - - if signed["mxid"] != event.state_key: - return False - if signed["token"] != token: - return False - - for public_key_object in Auther.get_public_keys(invite_event): - public_key = public_key_object["public_key"] - try: - for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): - if not key_name.startswith("ed25519:"): - continue - verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) - ) - verify_signed_json(signed, server, verify_key) - - # We got the public key from the invite, so we know that the - # correct server signed the signed bundle. - # The caller is responsible for checking that the signing - # server has not revoked that public key. - return True - except (KeyError, SignatureVerifyException,): - continue - return False - - @staticmethod - def get_public_keys(invite_event): - public_keys = [] - if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } - if "key_validity_url" in invite_event.content: - o["key_validity_url"] = invite_event.content["key_validity_url"] - public_keys.append(o) - public_keys.extend(invite_event.content.get("public_keys", [])) - return public_keys - - class Auth(object): """ FIXME: This class contains a mix of functions for authenticating users @@ -693,24 +74,7 @@ class Auth(object): True if the auth checks pass. 
""" with Measure(self.clock, "auth.check"): - Auther.check(event, auth_events, do_sig_check=do_sig_check) - - def check_size_limits(self, event): - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - - if len(event.user_id) > 255: - too_big("user_id") - if len(event.room_id) > 255: - too_big("room_id") - if event.is_state() and len(event.state_key) > 255: - too_big("state_key") - if len(event.type) > 255: - too_big("type") - if len(event.event_id) > 255: - too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: - too_big("event") + event_auth.check(event, auth_events, do_sig_check=do_sig_check) @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): @@ -804,7 +168,7 @@ class Auth(object): return creation_event.content.get("m.federate", True) is True def get_public_keys(self, invite_event): - return Auther.get_public_keys(invite_event) + return event_auth.get_public_keys(invite_event) @defer.inlineCallbacks def get_user_by_req(self, request, allow_guest=False, rights="access"): @@ -1198,7 +562,7 @@ class Auth(object): defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): - return Auther._get_send_level(etype, state_key, auth_events) + return event_auth._get_send_level(etype, state_key, auth_events) def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. @@ -1213,7 +577,7 @@ class Auth(object): AuthError if the event sender is definitely not allowed to redact the target event. """ - return Auther.check_redaction(event, auth_events) + return event_auth.check_redaction(event, auth_events) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -1243,10 +607,10 @@ class Auth(object): if power_level_event: auth_events[(EventTypes.PowerLevels, "")] = power_level_event - send_level = Auther._get_send_level( + send_level = event_auth.get_send_level( EventTypes.Aliases, "", auth_events ) - user_level = Auther._get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level(user_id, auth_events) if user_level < send_level: raise AuthError( diff --git a/synapse/event_auth.py b/synapse/event_auth.py new file mode 100644 index 000000000..983d8e9a8 --- /dev/null +++ b/synapse/event_auth.py @@ -0,0 +1,641 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from canonicaljson import encode_canonical_json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json, SignatureVerifyException +from unpaddedbase64 import decode_base64 + +from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.errors import AuthError, SynapseError, EventSizeError +from synapse.types import UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + + +def check(event, auth_events, do_sig_check=True): + """ Checks if this event is correctly authed. 
+ + Args: + event: the event being checked. + auth_events (dict: event-key -> event): the existing room state. + + + Returns: + True if the auth checks pass. + """ + _check_size_limits(event) + + if not hasattr(event, "room_id"): + raise AuthError(500, "Event has no room_id: %s" % event) + + if do_sig_check: + sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + if auth_events is None: + # Oh, we don't know what the state of the room was, so we + # are trusting that this is allowed (at least for now) + logger.warn("Trusting event: %s", event.event_id) + return True + + if event.type == EventTypes.Create: + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, + "Creation event's room_id domain does not match sender's" + ) + # FIXME + return True + + creation_event = auth_events.get((EventTypes.Create, ""), None) + + if not creation_event: + raise SynapseError( + 403, + "Room %r does not exist" % (event.room_id,) + ) + + creating_domain = get_domain_from_id(event.room_id) + originating_domain = get_domain_from_id(event.sender) + if creating_domain != originating_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # FIXME: Temp hack + if event.type == EventTypes.Aliases: + if not event.is_state(): + raise AuthError( + 403, + "Alias event must be a state event", + ) + if not event.state_key: + raise AuthError( + 403, + "Alias event must have non-empty state_key" + ) + sender_domain = get_domain_from_id(event.sender) + if event.state_key != sender_domain: + raise AuthError( + 403, + "Alias event's state_key does not match sender's domain" + ) + return True + + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) + + if event.type == EventTypes.Member: + allowed = _is_membership_change_allowed( + event, auth_events + ) + if allowed: + logger.debug("Allowing! %s", event) + else: + logger.debug("Denying! %s", event) + return allowed + + _check_event_sender_in_room(event, auth_events) + + # Special case to allow m.room.third_party_invite events wherever + # a user is allowed to issue invites. Fixes + # https://github.com/vector-im/vector-web/issues/1208 hopefully + if event.type == EventTypes.ThirdPartyInvite: + user_level = get_user_power_level(event.user_id, auth_events) + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, ( + "You cannot issue a third party invite for %s." 
% + (event.content.display_name,) + ) + ) + else: + return True + + _can_send_event(event, auth_events) + + if event.type == EventTypes.PowerLevels: + _check_power_levels(event, auth_events) + + if event.type == EventTypes.Redaction: + check_redaction(event, auth_events) + + logger.debug("Allowing! %s", event) + + +def _check_size_limits(event): + def too_big(field): + raise EventSizeError("%s too large" % (field,)) + + if len(event.user_id) > 255: + too_big("user_id") + if len(event.room_id) > 255: + too_big("room_id") + if event.is_state() and len(event.state_key) > 255: + too_big("state_key") + if len(event.type) > 255: + too_big("type") + if len(event.event_id) > 255: + too_big("event_id") + if len(encode_canonical_json(event.get_pdu_json())) > 65536: + too_big("event") + + +def _can_federate(event, auth_events): + creation_event = auth_events.get((EventTypes.Create, "")) + + return creation_event.content.get("m.federate", True) is True + + +def _is_membership_change_allowed(event, auth_events): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (EventTypes.Create, "", ) + create = auth_events.get(key) + if create and event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + + target_user_id = event.state_key + + creating_domain = get_domain_from_id(event.room_id) + target_domain = get_domain_from_id(target_user_id) + if creating_domain != target_domain: + if not _can_federate(event, auth_events): + raise AuthError( + 403, + "This room has been marked as unfederatable." + ) + + # get info about the caller + key = (EventTypes.Member, event.user_id, ) + caller = auth_events.get(key) + + caller_in_room = caller and caller.membership == Membership.JOIN + caller_invited = caller and caller.membership == Membership.INVITE + + # get info about the target + key = (EventTypes.Member, target_user_id, ) + target = auth_events.get(key) + + target_in_room = target and target.membership == Membership.JOIN + target_banned = target and target.membership == Membership.BAN + + key = (EventTypes.JoinRules, "", ) + join_rule_event = auth_events.get(key) + if join_rule_event: + join_rule = join_rule_event.content.get( + "join_rule", JoinRules.INVITE + ) + else: + join_rule = JoinRules.INVITE + + user_level = get_user_power_level(event.user_id, auth_events) + target_level = get_user_power_level( + target_user_id, auth_events + ) + + # FIXME (erikj): What should we do here as the default? 
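+    # (50 matches the defaults used below for "kick" and "redact".)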
+ ban_level = _get_named_level(auth_events, "ban", 50) + + logger.debug( + "_is_membership_change_allowed: %s", + { + "caller_in_room": caller_in_room, + "caller_invited": caller_invited, + "target_banned": target_banned, + "target_in_room": target_in_room, + "membership": membership, + "join_rule": join_rule, + "target_user_id": target_user_id, + "event.user_id": event.user_id, + } + ) + + if Membership.INVITE == membership and "third_party_invite" in event.content: + if not _verify_third_party_invite(event, auth_events): + raise AuthError(403, "You are not invited to this room.") + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + return True + + if Membership.JOIN != membership: + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + + if not caller_in_room: # caller isn't joined + raise AuthError( + 403, + "%s not in room %s." % (event.user_id, event.room_id,) + ) + + if Membership.INVITE == membership: + # TODO (erikj): We should probably handle this more intelligently + # PRIVATE join rules. + + # Invites are valid iff caller is in the room and target isn't. + if target_banned: + raise AuthError( + 403, "%s is banned from the room" % (target_user_id,) + ) + elif target_in_room: # the target is already in the room. + raise AuthError(403, "%s is already in the room." % + target_user_id) + else: + invite_level = _get_named_level(auth_events, "invite", 0) + + if user_level < invite_level: + raise AuthError( + 403, "You cannot invite user %s." % target_user_id + ) + elif Membership.JOIN == membership: + # Joins are valid iff caller == target and they were: + # invited: They are accepting the invitation + # joined: It's a NOOP + if event.user_id != target_user_id: + raise AuthError(403, "Cannot force another user to join.") + elif target_banned: + raise AuthError(403, "You are banned from this room") + elif join_rule == JoinRules.PUBLIC: + pass + elif join_rule == JoinRules.INVITE: + if not caller_in_room and not caller_invited: + raise AuthError(403, "You are not invited to this room.") + else: + # TODO (erikj): may_join list + # TODO (erikj): private rooms + raise AuthError(403, "You are not allowed to join this room") + elif Membership.LEAVE == membership: + # TODO (erikj): Implement kicks. + if target_banned and user_level < ban_level: + raise AuthError( + 403, "You cannot unban user &s." % (target_user_id,) + ) + elif target_user_id != event.user_id: + kick_level = _get_named_level(auth_events, "kick", 50) + + if user_level < kick_level or user_level <= target_level: + raise AuthError( + 403, "You cannot kick user %s." 
 % target_user_id
+                )
+    elif Membership.BAN == membership:
+        if user_level < ban_level or user_level <= target_level:
+            raise AuthError(403, "You don't have permission to ban")
+    else:
+        raise AuthError(500, "Unknown membership %s" % membership)
+
+    return True
+
+
+def _check_event_sender_in_room(event, auth_events):
+    key = (EventTypes.Member, event.user_id, )
+    member_event = auth_events.get(key)
+
+    return _check_joined_room(
+        member_event,
+        event.user_id,
+        event.room_id
+    )
+
+
+def _check_joined_room(member, user_id, room_id):
+    if not member or member.membership != Membership.JOIN:
+        raise AuthError(403, "User %s not in room %s (%s)" % (
+            user_id, room_id, repr(member)
+        ))
+
+
+def get_send_level(etype, state_key, auth_events):
+    key = (EventTypes.PowerLevels, "", )
+    send_level_event = auth_events.get(key)
+    send_level = None
+    if send_level_event:
+        send_level = send_level_event.content.get("events", {}).get(
+            etype
+        )
+        if send_level is None:
+            if state_key is not None:
+                send_level = send_level_event.content.get(
+                    "state_default", 50
+                )
+            else:
+                send_level = send_level_event.content.get(
+                    "events_default", 0
+                )
+
+    if send_level:
+        send_level = int(send_level)
+    else:
+        send_level = 0
+
+    return send_level
+
+
+def _can_send_event(event, auth_events):
+    send_level = get_send_level(
+        event.type, event.get("state_key", None), auth_events
+    )
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    if user_level < send_level:
+        raise AuthError(
+            403,
+            "You don't have permission to post that to the room. " +
+            "user_level (%d) < send_level (%d)" % (user_level, send_level)
+        )
+
+    # Check state_key
+    if hasattr(event, "state_key"):
+        if event.state_key.startswith("@"):
+            if event.state_key != event.user_id:
+                raise AuthError(
+                    403,
+                    "You are not allowed to set others' state"
+                )
+
+    return True
+
+
+def check_redaction(event, auth_events):
+    """Check whether the event sender is allowed to redact the target event.
+
+    Returns:
+        True if the sender is allowed to redact the target event if the
+        target event was created by them.
+        False if the sender is allowed to redact the target event with no
+        further checks.
+
+    Raises:
+        AuthError if the event sender is definitely not allowed to redact
+            the target event.
+    
+ """ + user_level = get_user_power_level(event.user_id, auth_events) + + redact_level = _get_named_level(auth_events, "redact", 50) + + if user_level >= redact_level: + return False + + redacter_domain = get_domain_from_id(event.event_id) + redactee_domain = get_domain_from_id(event.redacts) + if redacter_domain == redactee_domain: + return True + + raise AuthError( + 403, + "You don't have permission to redact events" + ) + + +def _check_power_levels(event, auth_events): + user_list = event.content.get("users", {}) + # Validate users + for k, v in user_list.items(): + try: + UserID.from_string(k) + except: + raise SynapseError(400, "Not a valid user_id: %s" % (k,)) + + try: + int(v) + except: + raise SynapseError(400, "Not a valid power level: %s" % (v,)) + + key = (event.type, event.state_key, ) + current_state = auth_events.get(key) + + if not current_state: + return + + user_level = get_user_power_level(event.user_id, auth_events) + + # Check other levels: + levels_to_check = [ + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), + ] + + old_list = current_state.content.get("users") + for user in set(old_list.keys() + user_list.keys()): + levels_to_check.append( + (user, "users") + ) + + old_list = current_state.content.get("events") + new_list = event.content.get("events") + for ev_id in set(old_list.keys() + new_list.keys()): + levels_to_check.append( + (ev_id, "events") + ) + + old_state = current_state.content + new_state = event.content + + for level_to_check, dir in levels_to_check: + old_loc = old_state + new_loc = new_state + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) + + if level_to_check in old_loc: + old_level = int(old_loc[level_to_check]) + else: + old_level = None + + if level_to_check in new_loc: + new_level = int(new_loc[level_to_check]) + else: + new_level = None + + if new_level is not None and old_level is not None: + if new_level == old_level: + continue + + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + + if old_level > user_level or new_level > user_level: + raise AuthError( + 403, + "You don't have permission to add ops level greater " + "than your own" + ) + + +def _get_power_level_event(auth_events): + key = (EventTypes.PowerLevels, "", ) + return auth_events.get(key) + + +def get_user_power_level(user_id, auth_events): + power_level_event = _get_power_level_event(auth_events) + + if power_level_event: + level = power_level_event.content.get("users", {}).get(user_id) + if not level: + level = power_level_event.content.get("users_default", 0) + + if level is None: + return 0 + else: + return int(level) + else: + key = (EventTypes.Create, "", ) + create_event = auth_events.get(key) + if (create_event is not None and + create_event.content["creator"] == user_id): + return 100 + else: + return 0 + + +def _get_named_level(auth_events, name, default): + power_level_event = _get_power_level_event(auth_events) + + if not power_level_event: + return default + + level = power_level_event.content.get(name, None) + if level is not None: + return int(level) + else: + return default + + +def _verify_third_party_invite(event, auth_events): + """ + Validates that the invite event is authorized by a previous third-party invite. 
+ + Checks that the public key, and keyserver, match those in the third party invite, + and that the invite event has a signature issued using that public key. + + Args: + event: The m.room.member join event being validated. + auth_events: All relevant previous context events which may be used + for authorization decisions. + + Return: + True if the event fulfills the expectations of a previous third party + invite event. + """ + if "third_party_invite" not in event.content: + return False + if "signed" not in event.content["third_party_invite"]: + return False + signed = event.content["third_party_invite"]["signed"] + for key in {"mxid", "token"}: + if key not in signed: + return False + + token = signed["token"] + + invite_event = auth_events.get( + (EventTypes.ThirdPartyInvite, token,) + ) + if not invite_event: + return False + + if invite_event.sender != event.sender: + return False + + if event.user_id != invite_event.user_id: + return False + + if signed["mxid"] != event.state_key: + return False + if signed["token"] != token: + return False + + for public_key_object in get_public_keys(invite_event): + public_key = public_key_object["public_key"] + try: + for server, signature_block in signed["signatures"].items(): + for key_name, encoded_signature in signature_block.items(): + if not key_name.startswith("ed25519:"): + continue + verify_key = decode_verify_key_bytes( + key_name, + decode_base64(public_key) + ) + verify_signed_json(signed, server, verify_key) + + # We got the public key from the invite, so we know that the + # correct server signed the signed bundle. + # The caller is responsible for checking that the signing + # server has not revoked that public key. + return True + except (KeyError, SignatureVerifyException,): + continue + return False + + +def get_public_keys(invite_event): + public_keys = [] + if "public_key" in invite_event.content: + o = { + "public_key": invite_event.content["public_key"], + } + if "key_validity_url" in invite_event.content: + o["key_validity_url"] = invite_event.content["key_validity_url"] + public_keys.append(o) + public_keys.extend(invite_event.content.get("public_keys", [])) + return public_keys From a3e4a198e3f5e0acd91d40d5743f97ece2cf5b6f Mon Sep 17 00:00:00 2001 From: Adrian Perez de Castro Date: Fri, 13 Jan 2017 17:12:04 +0200 Subject: [PATCH 11/45] Allow configuring the Riot URL used in notification emails The URLs used for notification emails were hardcoded to use either matrix.to or vector.im; but for self-hosted setups where Riot is also self-hosted it may be desirable to allow configuring an alternative Riot URL. Fixes #1809. 
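For example (riot.example.com is a hypothetical hostname), a deployment
that hosts its own Riot alongside synapse could add, next to its other
email settings:

    email:
      enable_notifs: true
      riot_base_url: "https://riot.example.com"

Notification links in the mails then point at URLs of the form
"https://riot.example.com/#/room/!roomid/$eventid" instead of the
vector.im or matrix.to equivalents.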
Signed-off-by: Adrian Perez de Castro
---
 synapse/config/emailconfig.py |  7 +++++++
 synapse/push/mailer.py        | 20 ++++++++++++++------
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index a18716127..0030b5db1 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -68,6 +68,9 @@ class EmailConfig(Config):
             self.email_notif_for_new_users = email_config.get(
                 "notif_for_new_users", True
             )
+            self.email_riot_base_url = email_config.get(
+                "riot_base_url", None
+            )
             if "app_name" in email_config:
                 self.email_app_name = email_config["app_name"]
             else:
@@ -85,6 +88,9 @@ class EmailConfig(Config):
     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
         # Enable sending emails for notification events
+        # Defining a custom URL for Riot is only needed if email notifications
+        # should contain links to a self-hosted installation of Riot; when set
+        # the "app_name" setting is ignored.
         #email:
         #   enable_notifs: false
         #   smtp_host: "localhost"
@@ -95,4 +101,5 @@ class EmailConfig(Config):
         #   notif_template_html: notif_mail.html
         #   notif_template_text: notif_mail.txt
         #   notif_for_new_users: True
+        #   riot_base_url: "http://localhost/riot"
         """

diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 53551632b..ce2d31fb9 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -439,15 +439,23 @@ class Mailer(object):
         })
 
     def make_room_link(self, room_id):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
-            return "https://vector.im/beta/#/room/%s" % (room_id,)
+        if self.hs.config.email_riot_base_url:
+            base_url = "%s/#/room" % (self.hs.config.email_riot_base_url,)
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
+            base_url = "https://vector.im/beta/#/room"
         else:
-            return "https://matrix.to/#/%s" % (room_id,)
+            base_url = "https://matrix.to/#"
+        return "%s/%s" % (base_url, room_id)
 
     def make_notif_link(self, notif):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
+        if self.hs.config.email_riot_base_url:
+            return "%s/#/room/%s/%s" % (
+                self.hs.config.email_riot_base_url,
+                notif['room_id'], notif['event_id']
+            )
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
             return "https://vector.im/beta/#/room/%s/%s" % (
                 notif['room_id'], notif['event_id']
             )

From c050f493dd53a74206338f9a5e567d7bd24fbd5d Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 13 Jan 2017 15:14:41 +0000
Subject: [PATCH 12/45] Add comment
---
 synapse/storage/schema/delta/40/device_inbox.sql | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql
index ce58fe208..b9fe1f048 100644
--- a/synapse/storage/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/schema/delta/40/device_inbox.sql
@@ -13,6 +13,7 @@
  * limitations under the License.
  */
 
+-- turn the pre-fill startup query into an index-only scan on postgresql.
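+-- (An index-only scan can satisfy the query from the index itself,
+-- without fetching the corresponding table rows.)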
INSERT into background_updates (update_name, progress_json) VALUES ('device_inbox_stream_index', '{}'); From e178feca3f9063c7a4f768298e889ee54b471e9b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 15:16:45 +0000 Subject: [PATCH 13/45] Remove unused function --- synapse/api/auth.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b781d41a6..280d4c445 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -561,9 +561,6 @@ class Auth(object): defer.returnValue(auth_ids) - def _get_send_level(self, etype, state_key, auth_events): - return event_auth._get_send_level(etype, state_key, auth_events) - def check_redaction(self, event, auth_events): """Check whether the event sender is allowed to redact the target event. From ec0a523ac338bab1eb23a6b21227b8f7402cc2d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 18:37:18 +0000 Subject: [PATCH 14/45] Split out static state methods from StateHandler --- synapse/state.py | 142 ++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a8..c75499c3e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,6 +16,7 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure @@ -335,9 +336,10 @@ class StateHandler(object): [state_map[e_id] for key, e_id in st.items() if e_id in state_map] for st in state_groups_ids.values() ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) + with Measure(self.clock, "state._resolve_events"): + new_state, _ = Resolver.resolve_events( + state_sets, event_type, state_key + ) new_state = { key: e.event_id for key, e in new_state.items() } @@ -388,68 +390,78 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + with Measure(self.clock, "state._resolve_events"): + if event.is_state(): + return Resolver.resolve_events( + state_sets, event.type, event.state_key + ) + else: + return Resolver.resolve_events(state_sets) - def _resolve_events(self, state_sets, event_type=None, state_key=""): + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) + + +class Resolver(object): + @staticmethod + def resolve_events(state_sets, event_type=None, state_key=""): """ Returns (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple (new_state, prev_states). new_state is a map from (type, state_key) to event. prev_states is a list of event_ids. 
""" - with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + try: + resolved_state = Resolver._resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - new_state = unconflicted_state - new_state.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): + @staticmethod + def _resolve_state_events(conflicted_state, auth_events): """ This is where we actually decide which of the conflicted state to use. 
@@ -464,7 +476,7 @@ class StateHandler(object): if power_key in conflicted_state: events = conflicted_state[power_key] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( + resolved_state[power_key] = Resolver._resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -472,7 +484,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -482,7 +494,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -492,14 +504,15 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( + resolved_state[key] = Resolver._resolve_normal_events( events, auth_events ) return resolved_state - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] + @staticmethod + def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] auth_events = dict(auth_events) @@ -507,23 +520,20 @@ class StateHandler(object): for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event return event - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): + @staticmethod + def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) return event except AuthError: pass @@ -531,9 +541,3 @@ class StateHandler(object): # Use the last event (the one with the least depth) if they all fail # the auth check. 
return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func) From 2fae34bd2ce152b8544d5a90fe3b35281c5fffbc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 17:46:17 +0000 Subject: [PATCH 15/45] Optionally measure size of cache by sum of length of values --- synapse/storage/roommember.py | 3 ++- synapse/storage/state.py | 2 +- synapse/util/caches/descriptors.py | 25 ++++++++++++++++++----- synapse/util/caches/lrucache.py | 32 +++++++++++++++++------------- tests/util/test_lrucache.py | 25 +++++++++++++++++++++++ 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7..e63aab6cc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=2000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40a..c480743f8 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=1000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49..d082c26b1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -42,6 +42,13 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +def deferred_size(deferred): + if deferred.called: + return len(deferred.result) + else: + return 1 + + class Cache(object): __slots__ = ( "cache", @@ -53,10 +60,11 @@ class Cache(object): "metrics", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=deferred_size if iterable else None, ) self.name = name @@ -155,7 +163,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +177,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +213,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, 
tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c67917..00ddf3829 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -58,6 +58,18 @@ class LruCache(object): lock = threading.Lock() + def cache_len(): + if size_callback is not None: + return sum(size_callback(node.value) for node in cache.itervalues()) + else: + return len(cache) + + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -127,22 +139,18 @@ class LruCache(object): else: callbacks = set() add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + + evict() @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: + evict() # As the new node may be bigger than the old node. 
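+                # (The entry is not replaced here, but with a size_callback
+                # its reported size can change after insertion, e.g.
+                # deferred_size reports len(result) once the deferred fires,
+                # so the cache may now be over budget.)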
return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -175,10 +183,6 @@ class LruCache(object): cb() cache.clear() - @synchronized - def cache_len(): - return len(cache) - @synchronized def cache_contains(key): return key in cache @@ -190,7 +194,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 1eba5b535..d888a64d0 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -232,3 +232,28 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 1) + + +class LruCacheSizedTestCase(unittest.TestCase): + + def test_evict(self): + cache = LruCache(5, size_callback=len) + cache["key1"] = [0] + cache["key2"] = [1, 2] + cache["key3"] = [3] + cache["key4"] = [4] + + self.assertEquals(cache["key1"], [0]) + self.assertEquals(cache["key2"], [1, 2]) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(len(cache), 5) + + cache["key5"] = [5, 6] + + self.assertEquals(len(cache), 4) + self.assertEquals(cache.get("key1"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(cache["key5"], [5, 6]) From 01521299c7d6d65b0f8b567bc7b7dbf94b7a81ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 11:56:51 +0000 Subject: [PATCH 16/45] Increase cache size limit --- synapse/storage/roommember.py | 2 +- synapse/storage/state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index e63aab6cc..8dce89073 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -391,7 +391,7 @@ class RoomMemberStore(SQLBaseStore): ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=2000) + max_entries=50000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c480743f8..fe942ecad 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000, iterable=True) + @cached(num_args=2, max_entries=50000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() From 46aebbbcbf94eb78ae45d3bb3bf3ffeabb44dd4f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 13:48:04 +0000 Subject: [PATCH 17/45] Add support for 'iterable' to ExpiringCache --- synapse/state.py | 6 +++++- synapse/util/caches/expiringcache.py | 26 +++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a8..461e82acd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", 
"state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(10000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -77,6 +77,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -99,6 +102,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958..9b44b3fab 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +36,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -49,7 +51,9 @@ class ExpiringCache(object): self._cache = {} - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable def start(self): if not self._expiry_ms: @@ -66,14 +70,15 @@ class ExpiringCache(object): self._cache[key] = _CacheEntry(now, value) # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: + if self._max_len and len(self) > self._max_len: sorted_entries = sorted( - self._cache.items(), + self._cache.keys(), key=lambda item: item[1].time, ) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + while len(self) > self._max_len and sorted_entries: + key = sorted_entries.pop() + self._cache.pop(key) def __getitem__(self, key): try: @@ -99,7 +104,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. 
return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -114,11 +119,14 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return sum(len(value.value) for value in self._cache.itervalues()) + else: + return len(self._cache) class _CacheEntry(object): From beda469bc6e96a0b776c3d6742cf97950819b2f0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:05:24 +0000 Subject: [PATCH 18/45] Put staticmethods at module level --- synapse/state.py | 252 +++++++++++++++++++++++------------------------ 1 file changed, 125 insertions(+), 127 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index c75499c3e..90b14e758 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -337,7 +337,7 @@ class StateHandler(object): for st in state_groups_ids.values() ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = Resolver.resolve_events( + new_state, _ = resolve_events( state_sets, event_type, state_key ) new_state = { @@ -392,11 +392,11 @@ class StateHandler(object): ) with Measure(self.clock, "state._resolve_events"): if event.is_state(): - return Resolver.resolve_events( + return resolve_events( state_sets, event.type, event.state_key ) else: - return Resolver.resolve_events(state_sets) + return resolve_events(state_sets) def _ordered_events(events): @@ -406,138 +406,136 @@ def _ordered_events(events): return sorted(events, key=key_func) -class Resolver(object): - @staticmethod - def resolve_events(state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e +def resolve_events(state_sets, event_type=None, state_key=""): + """ + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. 
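    For example (an illustrative call, not taken from the patch):
    resolving two state sets that disagree on a key yields new_state
    with the winning event at that key, and, when event_type/state_key
    name a conflicted key, prev_states lists the event_ids that were in
    conflict there.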
+ """ + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] + + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state + new_state.update(resolved_state) + + return new_state, prev_states + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events.update(resolved_state) + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - resolved_state = Resolver._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + prev_event = event + except AuthError: + return prev_event - new_state = unconflicted_state - new_state.update(resolved_state) + return event - return new_state, prev_states - @staticmethod - def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. 
+def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + return event + except AuthError: + pass - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = Resolver._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = Resolver._resolve_normal_events( - events, auth_events - ) - - return resolved_state - - @staticmethod - def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - @staticmethod - def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event + # Use the last event (the one with the least depth) if they all fail + # the auth check. 
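# An illustrative trace (assumed events, not from the patch): given
# conflicting events e1, e2, e3 in _ordered_events order, the loop above
# returns the first one that passes event_auth.check(); if all three
# raise AuthError, the fall-through below returns e3, the last event
# considered.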
+ return event From 897f8752da3c9f7b2d214fe91e8356be5db545c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:08:17 +0000 Subject: [PATCH 19/45] Up cache max entries for state --- synapse/state.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/storage/state.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index 461e82acd..66e1a685e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(10000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8dce89073..768e0a445 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -391,7 +391,7 @@ class RoomMemberStore(SQLBaseStore): ) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=50000) + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fe942ecad..7d34dd03b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=50000, iterable=True) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() From 6d00213e80fa51380c8ad7b339e7420edec27f9a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:33:22 +0000 Subject: [PATCH 20/45] Use OrderedDict in ExpiringCache --- synapse/util/caches/expiringcache.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 9b44b3fab..b9ead9cbd 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -49,7 +50,7 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() self.metrics = register_cache(cache_name, self) @@ -70,15 +71,8 @@ class ExpiringCache(object): self._cache[key] = _CacheEntry(now, value) # Evict if there are now too many items - if self._max_len and len(self) > self._max_len: - sorted_entries = sorted( - self._cache.keys(), - key=lambda item: item[1].time, - ) - - while len(self) > self._max_len and sorted_entries: - key = sorted_entries.pop() - self._cache.pop(key) + while self._max_len and len(self) > self._max_len: + self._cache.popitem(last=False) def __getitem__(self, key): try: From f2f179dce26f42ea0e691d17c60b297c63898923 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:33:34 +0000 Subject: [PATCH 21/45] Add ExpiringCache tests --- tests/util/test_expiring_cache.py | 84 +++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/util/test_expiring_cache.py diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py new file mode 100644 index 
000000000..31d24adb8
--- /dev/null
+++ b/tests/util/test_expiring_cache.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+
+    def test_get_set(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=1)
+
+        cache["key"] = "value"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache["key"], "value")
+
+    def test_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=2)
+
+        cache["key"] = "value"
+        cache["key2"] = "value2"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache.get("key2"), "value2")
+
+        cache["key3"] = "value3"
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEquals(cache.get("key3"), "value3")
+
+    def test_iterable_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+        cache["key"] = [1]
+        cache["key2"] = [2, 3]
+        cache["key3"] = [4, 5]
+
+        self.assertEquals(cache.get("key"), [1])
+        self.assertEquals(cache.get("key2"), [2, 3])
+        self.assertEquals(cache.get("key3"), [4, 5])
+
+        cache["key4"] = [6, 7]
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEquals(cache.get("key4"), [6, 7])
+
+    def test_time_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache.start()
+
+        cache["key"] = 1
+        clock.advance_time(0.5)
+        cache["key2"] = 2
+
+        self.assertEquals(cache.get("key"), 1)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(0.9)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(1)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)

From f85b6ca494ae587731d99196020cc74d7eca012a Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 11:18:13 +0000
Subject: [PATCH 22/45] Speed up cache size calculation

Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, cache the size and update
it on insertion and deletion.

This requires changing the cache descriptors to have two caches, one
for pending deferreds and the other for the actual values. There's no
reason to evict from the pending deferreds as they won't take up any
more memory.
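To make the new shape concrete, a minimal sketch of the two-cache
pattern (illustrative names only, not the actual Cache class in the
diff below): in-flight deferreds live in a plain, unbounded dict, and
only completed results are promoted into the size-bounded LRU, so the
size callback only ever sees finished values.

    class TwoLevelCache(object):
        """Sketch: pending lookups kept outside the bounded result cache."""

        def __init__(self, lru):
            self.lru = lru       # bounded, size-aware result cache
            self.pending = {}    # unbounded map of in-flight deferreds

        def get(self, key, default=None):
            deferred = self.pending.get(key)
            if deferred is not None:
                return deferred  # still being computed, no size yet
            return self.lru.get(key, default)

        def set(self, key, deferred):
            self.pending[key] = deferred

            def promote(result):
                # Only completed results count towards the size limit.
                if self.pending.pop(key, None) is deferred:
                    self.lru.set(key, result)
                return result

            deferred.addCallback(promote)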
--- synapse/util/caches/descriptors.py | 97 +++++++++++++++++++------ synapse/util/caches/dictionary_cache.py | 6 +- synapse/util/caches/expiringcache.py | 15 +++- synapse/util/caches/lrucache.py | 42 ++++++----- synapse/util/caches/treecache.py | 14 +++- tests/storage/test__base.py | 6 +- tests/util/test_lrucache.py | 30 ++++---- 7 files changed, 148 insertions(+), 62 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d082c26b1..b3b2d6092 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, popped_to_iterator from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,11 +42,23 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -def deferred_size(deferred): - if deferred.called: - return len(deferred.result) - else: - return 1 +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() class Cache(object): @@ -58,13 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, - size_callback=deferred_size if iterable else None, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -84,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -96,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = 
self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -116,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -127,6 +178,12 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + val = self._pending_deferred_cache.pop(key, None) + if val is not None: + entry_dict, _ = val + for entry in popped_to_iterator(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -254,11 +311,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -272,7 +324,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -370,7 +422,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -393,8 +444,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79..cb6933c61 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index b9ead9cbd..2987c38a2 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -56,6 +56,8 @@ class ExpiringCache(object): self.iterable = iterable + self._size_estimate = 0 + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire @@ -70,9 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + if self.iterable: + self._size_estimate += 
len(value) + # Evict if there are now too many items while self._max_len and len(self) > self._max_len: - self._cache.popitem(last=False) + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -109,7 +116,9 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -118,7 +127,7 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return sum(len(value.value) for value in self._cache.itervalues()) + return self._size_estimate else: return len(self._cache) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 00ddf3829..f1de03444 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -58,12 +58,6 @@ class LruCache(object): lock = threading.Lock() - def cache_len(): - if size_callback is not None: - return sum(size_callback(node.value) for node in cache.itervalues()) - else: - return len(cache) - def evict(): while cache_len() > max_size: todelete = list_root.prev_node @@ -78,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -86,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -104,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -128,17 +137,16 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) + add_node(key, value, set(callbacks)) evict() diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea..460e98a92 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,24 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(popped_to_iterator(self.root)) def __len__(self): return self.size +def popped_to_iterator(d): + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in 
popped_to_iterator(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ab6095564..8361dd8ce 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=2) + @cached(max_entries=20) # HACK: This makes it 2 due to cache factor def func(self, key): callcount[0] += 1 return key @@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 2) self.assertEquals(callcount2[0], 2) + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + yield a.func("foo3") self.assertEquals(callcount[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index d888a64d0..99aab6500 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.get("key", "value") @@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value2") @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", m1) - cache.set(("a", "2"), "value", m2) - cache.set(("b", "1"), "value", m3) - cache.set(("b", "2"), "value", m4) + cache.set(("a", "1"), "value", [m1]) + cache.set(("a", "2"), "value", [m2]) + cache.set(("b", "1"), "value", [m3]) + cache.set(("b", "2"), "value", [m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", m3) + cache.set("key3", "value", [m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", 
m1) + cache.set("key1", "value", [m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) From d9062060499d670f41ebc31d43003bed3502a722 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:25:51 +0000 Subject: [PATCH 23/45] Increase state_group_cache_size --- synapse/storage/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5620a655e..963ef999d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() From 1ccd5676e3fe01bcc1c59fd06f400f629b24c3ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:42:26 +0000 Subject: [PATCH 24/45] Remove needless call to evict() --- synapse/util/caches/lrucache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f1de03444..072f9a9d1 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -154,7 +154,6 @@ class LruCache(object): def cache_set_default(key, value): node = cache.get(key, None) if node is not None: - evict() # As the new node may be bigger than the old node. return node.value else: add_node(key, value) From d6c75cb7c237a31252f0838d2aa6114cd58b2ad4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:44:57 +0000 Subject: [PATCH 25/45] Rename and comment tree_to_leaves_iterator --- synapse/util/caches/descriptors.py | 4 ++-- synapse/util/caches/treecache.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3b2d6092..a9ea97fd4 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, popped_to_iterator +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -181,7 +181,7 @@ class Cache(object): val = self._pending_deferred_cache.pop(key, None) if val is not None: entry_dict, _ = val - for entry in popped_to_iterator(entry_dict): + for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() def invalidate_all(self): diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 460e98a92..fcc341a6b 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,16 +65,19 @@ class TreeCache(object): return popped def values(self): - return list(popped_to_iterator(self.root)) + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size -def popped_to_iterator(d): +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. 
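    For example (illustrative), iterating {"a": {"b": _Entry(1)}, "c": _Entry(2)}
    yields 1 and 2 in dict iteration order: plain dict levels recurse,
    while _Entry leaves yield their wrapped .value.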
+ """ if isinstance(d, dict): for value_d in d.itervalues(): - for value in popped_to_iterator(value_d): + for value in iterate_tree_cache_entry(value_d): yield value else: if isinstance(d, _Entry): From 9e8e236d9816ef639bdeb72cbb4de0fc29c6b120 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:48:02 +0000 Subject: [PATCH 26/45] Tidy up test --- tests/util/test_lrucache.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 99aab6500..dfb78cb8b 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", [m1]) - cache.set(("a", "2"), "value", [m2]) - cache.set(("b", "1"), "value", [m3]) - cache.set(("b", "2"), "value", [m4]) + cache.set(("a", "1"), "value", callbacks=[m1]) + cache.set(("a", "2"), "value", callbacks=[m2]) + cache.set(("b", "1"), "value", callbacks=[m3]) + cache.set(("b", "2"), "value", callbacks=[m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", [m3]) + cache.set("key3", "value", callbacks=[m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", [m1]) + cache.set("key1", "value", callbacks=[m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) From 5d6bad1b3c325897db81f84ebfc67ca687d851c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 13:16:54 +0000 Subject: [PATCH 27/45] Optimise state resolution --- synapse/event_auth.py | 49 +++++++- synapse/events/__init__.py | 8 +- synapse/events/builder.py | 6 +- synapse/handlers/federation.py | 2 +- synapse/state.py | 211 +++++++++++++++++++++++---------- tests/api/test_filtering.py | 5 +- tests/events/test_utils.py | 22 +++- 7 files changed, 230 insertions(+), 73 deletions(-) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 983d8e9a8..3b7726a52 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -27,7 +27,7 @@ from synapse.types import UserID, get_domain_from_id 
logger = logging.getLogger(__name__) -def check(event, auth_events, do_sig_check=True): +def check(event, auth_events, do_sig_check=True, do_size_check=True): """ Checks if this event is correctly authed. Args: @@ -38,7 +38,8 @@ def check(event, auth_events, do_sig_check=True): Returns: True if the auth checks pass. """ - _check_size_limits(event) + if do_size_check: + _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) @@ -119,10 +120,11 @@ def check(event, auth_events, do_sig_check=True): ) return True - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) if event.type == EventTypes.Member: allowed = _is_membership_change_allowed( @@ -639,3 +641,38 @@ def get_public_keys(invite_event): public_keys.append(o) public_keys.extend(invite_event.content.get("public_keys", [])) return public_keys + + +def auth_types_for_event(event): + """Given an event, return a list of (EventType, StateKey) that may be + needed to auth the event. The returned list may be a superset of what + would actually be required depending on the full state of the room. + + Used to limit the number of events to fetch from the database to + actually auth the event. + """ + if event.type == EventTypes.Create: + return [] + + auth_types = [] + + auth_types.append((EventTypes.PowerLevels, "", )) + auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Create, "", )) + + if event.type == EventTypes.Member: + e_type = event.content["membership"] + if e_type in [Membership.JOIN, Membership.INVITE]: + auth_types.append((EventTypes.JoinRules, "", )) + + auth_types.append((EventTypes.Member, event.state_key, )) + + if e_type == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"] + ) + auth_types.append(key) + + return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index da9f3ad43..e673e96cc 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -79,7 +79,6 @@ class EventBase(object): auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") content = _event_dict_property("content") - event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") @@ -88,8 +87,6 @@ class EventBase(object): redacts = _event_dict_property("redacts") room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") user_id = _event_dict_property("sender") @property @@ -162,6 +159,11 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict + self.event_id = event_dict["event_id"] + self.type = event_dict["type"] + if "state_key" in event_dict: + self.state_key = event_dict["state_key"] + super(FrozenEvent, self).__init__( frozen_dict, signatures=signatures, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7369d7098..365fd96bd 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import EventBase, FrozenEvent +from . 
import EventBase, FrozenEvent, _event_dict_property from synapse.types import EventID @@ -34,6 +34,10 @@ class EventBuilder(EventBase): internal_metadata_dict=internal_metadata_dict, ) + event_id = _event_dict_property("event_id") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") + def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1021bcc40..ea89e0cf2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state, prev_state = self.state_handler.resolve_events( + new_state = self.state_handler.resolve_events( [local_view.values(), remote_view.values()], event ) diff --git a/synapse/state.py b/synapse/state.py index 90b14e758..294e0c208 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -22,11 +22,10 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer -from collections import namedtuple +from collections import namedtuple, defaultdict from frozendict import frozendict import logging @@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60 _NEXT_STATE_ID = 1 +POWER_KEY = (EventTypes.PowerLevels, "") + def _gen_state_id(): global _NEXT_STATE_ID @@ -328,21 +329,13 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = resolve_events( - state_sets, event_type, state_key + new_state = yield resolve_events( + state_groups_ids.values(), + state_map_factory=lambda ev_ids: self.store.get_events( + ev_ids, get_prev_content=False + ), ) - new_state = { - key: e.event_id for key, e in new_state.items() - } else: new_state = { key: e_ids.pop() for key, e_ids in state.items() @@ -390,13 +383,25 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } + with Measure(self.clock, "state._resolve_events"): - if event.is_state(): - return resolve_events( - state_sets, event.type, event.state_key - ) - else: - return resolve_events(state_sets) + new_state = resolve_events(state_set_ids, state_map) + + new_state = { + key: state_map[ev_id] for key, ev_id in new_state.items() + } + + return new_state def _ordered_events(events): @@ -406,43 +411,117 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, event_type=None, state_key=""): +def resolve_events(state_sets, state_map_factory): """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. 
+ state_map_factory(dict|callable): If callable, then will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. Otherwise, should be + a dict from event_id to event of all events in state_sets. + Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. + dict[(str, str), synapse.events.FrozenEvent] is a map from + (type, state_key) to event. """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } - - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } - - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + if callable(state_map_factory): + return _resolve_with_state_fac( + unconflicted_state, conflicted_state, state_map_factory ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + + state_map = state_map_factory + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + """ + unconflicted_state = dict(state_sets[0]) + conflicted_state = {} + + full_states = defaultdict( + set, + {k: set((v,)) for k, v in state_sets[0].iteritems()} + ) + + for state_set in state_sets[1:]: + for key, value in state_set.iteritems(): + ls = full_states[key] + if not ls: + ls.add(value) + unconflicted_state[key] = value + elif value not in ls: + ls.add(value) + if len(ls) == 2: + conflicted_state[key] = ls + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +@defer.inlineCallbacks +def _resolve_with_state_fac(unconflicted_state, conflicted_state, + state_map_factory): + needed_events = set( + event_id + for event_ids in conflicted_state.itervalues() + for event_id in event_ids + ) + + state_map = yield state_map_factory(needed_events) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(auth_events.itervalues()) + new_needed_events -= needed_events + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, + state_map): + conflicted_state = { + key: [state_map[ev_id] for ev_id in event_ids] + for key, event_ids in conflicted_state.items() + } auth_events = { - k: e for 
k, e in unconflicted_state.items() - if k[0] in AuthEventTypes + key: state_map[ev_id] + for key, ev_id in auth_events.items() } try: @@ -454,9 +533,10 @@ def resolve_events(state_sets, event_type=None, state_key=""): raise new_state = unconflicted_state - new_state.update(resolved_state) + for key, event in resolved_state.iteritems(): + new_state[key] = event.event_id - return new_state, prev_states + return new_state def _resolve_state_events(conflicted_state, auth_events): @@ -470,11 +550,10 @@ def _resolve_state_events(conflicted_state, auth_events): 4. other events. """ resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = _resolve_auth_events( + resolved_state[POWER_KEY] = _resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -512,14 +591,26 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] - auth_events = dict(auth_events) + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events prev_event = reverse[0] for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) prev_event = event except AuthError: return prev_event @@ -531,7 +622,7 @@ def _resolve_normal_events(events, auth_events): for event in _ordered_events(events): try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) return event except AuthError: pass diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcb6c5bc3..50e8607c1 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,10 +25,13 @@ from synapse.api.filtering import Filter from synapse.events import FrozenEvent user_localpart = "test_user" -# MockEvent = namedtuple("MockEvent", "sender type room_id") def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 29f068d1f..dfc870066 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) @@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase): def test_minimal(self): self.run_test( - {'type': 'A'}, { 'type': 'A', + 'event_id': '$test:domain', + }, + { + 'type': 'A', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'age_ts': 20},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {'age_ts': 20},
@@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'other_key': 'here'},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {'things': 'here'},
             },
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -109,10 +123,12 @@
         self.run_test(
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain', 'other_field': 'here'},
             },
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain'},
                 'signatures': {},
                 'unsigned': {},
@@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
         self.assertEquals(
             self.serialize(
                 MockEvent(
+                    type="foo",
+                    event_id="test",
                     room_id="!foo:bar",
                     content={
                         "foo": "bar",
@@ -263,6 +281,8 @@
                 []
             ),
             {
+                "type": "foo",
+                "event_id": "test",
                 "room_id": "!foo:bar",
                 "content": {
                     "foo": "bar",

From e6153e1bd10529b28b69820decbc039b9d6a1f27 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 13 Jan 2017 13:21:04 +0000
Subject: [PATCH 28/45] Fix a couple of federation state bugs

---
 synapse/federation/federation_client.py | 6 ++++--
 synapse/handlers/federation.py          | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b4bcec77e..c9175bb33 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, builder
 import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -499,8 +499,10 @@ class FederationClient(FederationBase):
             if "prev_state" not in pdu_dict:
                 pdu_dict["prev_state"] = []
 
+            ev = builder.EventBuilder(pdu_dict)
+
             defer.returnValue(
-                (destination, self.event_from_pdu_json(pdu_dict))
+                (destination, ev)
             )
             break
         except CodeMessageException as e:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ea89e0cf2..ced5646e9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -596,7 +596,7 @@ class FederationHandler(BaseHandler):
             preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
         ]))
-        states = dict(zip(event_ids, [s[1] for s in states]))
+        states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
             [e_id for ids in states.values() for e_id in ids],

From 633f97151c6c7fa693b3de4addad641186b4ef02 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 13:33:54 +0000
Subject: [PATCH 29/45] Check event is in state_map

---
 synapse/state.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/synapse/state.py b/synapse/state.py
index 294e0c208..df9b6b3cc 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -333,7 +333,7 @@ class StateHandler(object):
             new_state = yield resolve_events(
                 state_groups_ids.values(),
                 state_map_factory=lambda ev_ids: self.store.get_events(
-                    ev_ids, get_prev_content=False
+                    ev_ids, get_prev_content=False, check_redacted=False,
                 ),
             )
         else:
@@ -482,6 +482,8 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
         for event_id in event_ids
     )
 
+    logger.info("Asking for %d conflicted events", len(needed_events))
+
     state_map = yield state_map_factory(needed_events)
 
     auth_events = _create_auth_events_from_maps(
@@ -491,6 +493,8 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
     new_needed_events = set(auth_events.itervalues())
     new_needed_events -= needed_events
 
+    logger.info("Asking for %d auth events", len(new_needed_events))
+
     state_map_new = yield state_map_factory(new_needed_events)
     state_map.update(state_map_new)
 
@@ -515,13 +519,14 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
 
 def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
                         state_map):
     conflicted_state = {
-        key: [state_map[ev_id] for ev_id in event_ids]
+        key: [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
         for key, event_ids in conflicted_state.items()
     }
 
     auth_events = {
         key: state_map[ev_id]
         for key, ev_id in auth_events.items()
+        if ev_id in state_map
     }
 
     try:

From ce59a2faad253409a8047ce9302d3d6c087fe812 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:18:53 +0000
Subject: [PATCH 30/45] Correctly handle case of rejected events in state res

---
 synapse/state.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/synapse/state.py b/synapse/state.py
index df9b6b3cc..d2bd1ad64 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -507,21 +507,27 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
     auth_events = {}
     for event_ids in conflicted_state.itervalues():
         for event_id in event_ids:
-            keys = event_auth.auth_types_for_event(state_map[event_id])
-            for key in keys:
-                if key not in auth_events:
-                    event_id = unconflicted_state.get(key, None)
-                    if event_id:
-                        auth_events[key] = event_id
+            if event_id in state_map:
+                keys = event_auth.auth_types_for_event(state_map[event_id])
+                for key in keys:
+                    if key not in auth_events:
+                        event_id = unconflicted_state.get(key, None)
+                        if event_id:
+                            auth_events[key] = event_id
     return auth_events
 
 
 def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
                         state_map):
-    conflicted_state = {
-        key: [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
-        for key, event_ids in conflicted_state.items()
-    }
+    new_conflicted_state = {}
+    for key, event_ids in conflicted_state.iteritems():
+        events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
+        if len(events) > 1:
+            new_conflicted_state[key] = events
+        elif len(events) == 1:
+            unconflicted_state[key] = events[0].event_id
+
+    conflicted_state = new_conflicted_state
 
     auth_events = {
         key: state_map[ev_id]

From 04006bb7f014fa62c1534fac7250e7b845fa91d3 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:31:21 +0000
Subject: [PATCH 31/45] Get state at event rather than for room in push

---
 synapse/push/push_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index b47bf1f92..a27476bba 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
 def get_context_for_event(store, state_handler, ev, user_id):
     ctx = {}
 
-    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
+    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or

From e5d2df9c3452617e3390b2c356e11b7c49b022b1 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:32:53 +0000
Subject: [PATCH 32/45] Use better variable name

---
 synapse/event_auth.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 3b7726a52..4096c606f 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -661,13 +661,13 @@ def auth_types_for_event(event):
         auth_types.append((EventTypes.Create, "", ))
 
     if event.type == EventTypes.Member:
-        e_type = event.content["membership"]
-        if e_type in [Membership.JOIN, Membership.INVITE]:
+        membership = event.content["membership"]
+        if membership in [Membership.JOIN, Membership.INVITE]:
             auth_types.append((EventTypes.JoinRules, "", ))
 
         auth_types.append((EventTypes.Member, event.state_key, ))
 
-        if e_type == Membership.INVITE:
+        if membership == Membership.INVITE:
             if "third_party_invite" in event.content:
                 key = (
                     EventTypes.ThirdPartyInvite,

From 37b4c7d8a94203f790c0db408c114ec0004a2cc8 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:43:32 +0000
Subject: [PATCH 33/45] Fix typo in return type

---
 synapse/util/caches/descriptors.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index a9ea97fd4..675bfd5fe 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -178,9 +178,8 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
-        val = self._pending_deferred_cache.pop(key, None)
-        if val is not None:
-            entry_dict, _ = val
+        entry_dict = self._pending_deferred_cache.pop(key, None)
+        if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()

From a8594fd19f48a179b263d58ba1f9c5ab2f4cb8d3 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 14:59:03 +0000
Subject: [PATCH 34/45] Use better names

---
 synapse/state.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/synapse/state.py b/synapse/state.py
index d2bd1ad64..6f62876f8 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -517,21 +517,19 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
     return auth_events
 
 
-def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
                         state_map):
-    new_conflicted_state = {}
-    for key, event_ids in conflicted_state.iteritems():
+    conflicted_state = {}
+    for key, event_ids in conflicted_state_ids.iteritems():
         events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
         if len(events) > 1:
-            new_conflicted_state[key] = events
+            conflicted_state[key] = events
         elif len(events) == 1:
-            unconflicted_state[key] = events[0].event_id
-
-    conflicted_state = new_conflicted_state
+            unconflicted_state_ids[key] = events[0].event_id
 
     auth_events = {
         key: state_map[ev_id]
-        for key, ev_id in auth_events.items()
+        for key, ev_id in auth_event_ids.items()
         if ev_id in state_map
     }
 
@@ -543,7 +541,7 @@ def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
         logger.exception("Failed to resolve state")
         raise
 
-    new_state = unconflicted_state
+    new_state = unconflicted_state_ids
 
     for key, event in resolved_state.iteritems():
         new_state[key] = event.event_id

From c6064a7ba6bae6055dc7960e5eef3956131b718d Mon Sep 17 00:00:00 2001
From: Mark Haines
Date: Tue, 17 Jan 2017 15:23:07 +0000
Subject: [PATCH 35/45] Only construct sets when necessary

---
 synapse/state.py | 33 +++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/synapse/state.py b/synapse/state.py
index 6f62876f8..81c6bae73 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -453,22 +453,27 @@ def _seperate(state_sets):
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
 
-    full_states = defaultdict(
-        set,
-        {k: set((v,)) for k, v in state_sets[0].iteritems()}
-    )
-
     for state_set in state_sets[1:]:
         for key, value in state_set.iteritems():
-            ls = full_states[key]
-            if not ls:
-                ls.add(value)
-                unconflicted_state[key] = value
-            elif value not in ls:
-                ls.add(value)
-                if len(ls) == 2:
-                    conflicted_state[key] = ls
-                    unconflicted_state.pop(key, None)
+            # Check if there is an unconflicted entry for the state key.
+            unconflicted_value = unconflicted_state.get(key)
+            if unconflicted_value is None:
+                # There isn't an unconflicted entry so check if there is a
+                # conflicted entry.
+                ls = conflicted_state.get(key)
+                if ls is None:
+                    # There wasn't a conflicted entry so haven't seen this key before.
+                    # Therefore it isn't conflicted yet.
+                    unconflicted_state[key] = value
+                else:
+                    # This key is already conflicted, add our value to the conflict set.
+                    ls.add(value)
+            elif unconflicted_value != value:
+                # If the unconflicted value is not the same as our value then we
+                # have a new conflict. So move the key from the unconflicted_state
+                # to the conflicted state.
+                conflicted_state[key] = {value, unconflicted_value}
+                unconflicted_state.pop(key, None)
 
     return unconflicted_state, conflicted_state
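
A note on patch 35: the old `_seperate` allocated a set per state key up
front via `defaultdict(set)`, whereas the rewrite keeps unconflicted keys
as plain dict entries and only builds a set once a second distinct event
ID turns up for a key. A minimal standalone sketch of the same algorithm
(function and demo names are illustrative; like the patched code, it
assumes state values are never None):

    def separate(state_sets):
        # Split a list of state dicts into (unconflicted, conflicted).
        unconflicted = dict(state_sets[0])
        conflicted = {}

        for state_set in state_sets[1:]:
            for key, value in state_set.items():
                unconflicted_value = unconflicted.get(key)
                if unconflicted_value is None:
                    ls = conflicted.get(key)
                    if ls is None:
                        # First sighting of this key: provisionally unconflicted.
                        unconflicted[key] = value
                    else:
                        # Already conflicted: just extend the existing set.
                        ls.add(value)
                elif unconflicted_value != value:
                    # Second distinct value: only now pay for a set allocation.
                    conflicted[key] = {value, unconflicted_value}
                    unconflicted.pop(key, None)

        return unconflicted, conflicted

For instance, separate([{'a': '$1'}, {'a': '$1', 'b': '$2'}, {'a': '$3'}])
returns ({'b': '$2'}, {'a': {'$1', '$3'}}).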
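
Patches 29, 30 and 34 together make this resolver tolerate event IDs that
the event store never returns -- notably rejected events, which the
default `get_events` lookup omits. Pulled out of the surrounding code, the
filtering step amounts to the following (a sketch with illustrative names,
not the literal synapse function):

    def filter_conflicted(conflicted_ids, unconflicted_ids, state_map):
        # Drop candidates we could not load, and demote any key left with
        # a single candidate back to the unconflicted map.
        conflicted = {}
        for key, event_ids in conflicted_ids.items():
            events = [state_map[e] for e in event_ids if e in state_map]
            if len(events) > 1:
                conflicted[key] = events                    # still contested
            elif len(events) == 1:
                unconflicted_ids[key] = events[0].event_id  # conflict evaporated
            # if every candidate was missing, the key is dropped entirely
        return conflicted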
From ed4d1761525b21989279b99733e415c1c86ed39f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 15:27:28 +0000
Subject: [PATCH 36/45] PEP8

---
 synapse/state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/state.py b/synapse/state.py
index 81c6bae73..15238cd00 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
 
-from collections import namedtuple, defaultdict
+from collections import namedtuple
 from frozendict import frozendict
 
 import logging

From 380dba1020294b2c43ffb433b86917d0ee6cf9c0 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:04:46 +0000
Subject: [PATCH 37/45] Measure metrics of string_cache

---
 synapse/util/caches/__init__.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ebd715c5d..8a7774a88 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -40,8 +40,8 @@ def register_cache(name, cache):
     )
 
 
-_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
-caches_by_name["string_cache"] = _string_cache
+_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
+_string_cache_metrics = register_cache("string_cache", _string_cache)
 
 
 KNOWN_KEYS = {
@@ -69,7 +69,12 @@ KNOWN_KEYS = {
 def intern_string(string):
     """Takes a (potentially) unicode string and interns using custom cache
     """
-    return _string_cache.setdefault(string, string)
+    new_str = _string_cache.setdefault(string, string)
+    if new_str is string:
+        _string_cache_metrics.inc_hits()
+    else:
+        _string_cache_metrics.inc_misses()
+    return new_str
 
 
 def intern_dict(dictionary):

From 5f027d1fc54ab51b420e3deb25d83ac05676fdbf Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:07:15 +0000
Subject: [PATCH 38/45] Change resolve_state_groups call site logging to DEBUG

---
 synapse/api/auth.py            | 2 +-
 synapse/handlers/federation.py | 2 +-
 synapse/state.py               | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 280d4c445..03a215ab1 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -146,7 +146,7 @@ class Auth(object):
         with Measure(self.clock, "check_host_in_room"):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-            logger.info("calling resolve_state_groups from check_host_in_room")
+            logger.debug("calling resolve_state_groups from check_host_in_room")
             entry = yield self.state.resolve_state_groups(
                 room_id, latest_event_ids
             )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1021bcc40..2e310fed7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -591,7 +591,7 @@ class FederationHandler(BaseHandler):
 
         event_ids = list(extremities.keys())
 
-        logger.info("calling resolve_state_groups in _maybe_backfill")
+        logger.debug("calling resolve_state_groups in _maybe_backfill")
         states = yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
diff --git a/synapse/state.py b/synapse/state.py
index 5028b0ac4..1c6e31ac5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -128,7 +128,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        logger.info("calling resolve_state_groups from get_current_state")
+        logger.debug("calling resolve_state_groups from get_current_state")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
 
@@ -153,7 +153,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        logger.info("calling resolve_state_groups from get_current_state_ids")
+        logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
 
@@ -167,7 +167,7 @@ class StateHandler(object):
     def get_current_user_in_room(self, room_id, latest_event_ids=None):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        logger.info("calling resolve_state_groups from get_current_user_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_in_room")
         entry = yield self.resolve_state_groups(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(
             room_id, entry.state_id, entry.state
@@ -231,7 +231,7 @@ class StateHandler(object):
             context.prev_state_events = []
             defer.returnValue(context)
 
-        logger.info("calling resolve_state_groups from compute_event_context")
+        logger.debug("calling resolve_state_groups from compute_event_context")
         if event.is_state():
             entry = yield self.resolve_state_groups(
                 event.room_id, [e for e, _ in event.prev_events],

From f878f64f4314dae6bd68b11ad1edbf0883f9bd8f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:20:39 +0000
Subject: [PATCH 39/45] Lower the not retrying host log line to debug

---
 synapse/federation/transaction_queue.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 7db7b806d..6b3a7abb9 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -362,7 +362,7 @@ class TransactionQueue(object):
                 if not success:
                     break
         except NotRetryingDestination:
-            logger.info(
+            logger.debug(
                 "TX [%s] not ready for retry yet - "
                 "dropping transaction for now",
                 destination,

From 4ec1cf49e20f35bad2d54575fad23c8e21f8d66f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 17:18:13 +0000
Subject: [PATCH 40/45] Lower loading events log to DEBUG

---
 synapse/storage/events.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 04dbdac3f..ca501932f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1084,10 +1084,10 @@ class EventsStore(SQLBaseStore):
                 self._do_fetch
             )
 
-        logger.info("Loading %d events", len(events))
+        logger.debug("Loading %d events", len(events))
         with PreserveLoggingContext():
             rows = yield events_d
-        logger.info("Loaded %d events (%d rows)", len(events), len(rows))
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
 
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
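
On patch 37: `LruCache.setdefault(string, string)` hands back the caller's
own object both when the string was just inserted and when the caller
already held the canonical copy, so the identity test `new_str is string`
cannot cleanly separate hits from misses; the patch books the `is` case as
a hit. A toy model of the interning pattern, using a plain dict (no
eviction) in place of synapse's LruCache and booking first-time inserts as
misses instead:

    class InternCache(object):
        def __init__(self):
            self.cache = {}   # stands in for LruCache; no eviction here
            self.hits = 0
            self.misses = 0

        def intern(self, s):
            new_s = self.cache.setdefault(s, s)
            if new_s is s:
                self.misses += 1  # first sighting (or caller held the canonical copy)
            else:
                self.hits += 1    # an equal string was already cached
            return new_s

    cache = InternCache()
    prefix = u"m.room."
    a = cache.intern(u"m.room.member")    # inserted: miss
    b = cache.intern(prefix + u"member")  # equal but distinct object: hit
    assert a is b and cache.hits == 1 and cache.misses == 1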
From 8c5009b6282b10b2248f080cd9021a799aad5285 Mon Sep 17 00:00:00 2001
From: David Baker
Date: Wed, 18 Jan 2017 13:25:56 +0000
Subject: [PATCH 41/45] Lowercase all email addresses before querying db

Since we store all emails in the DB in lowercase
(https://github.com/matrix-org/synapse/pull/1170)
---
 synapse/rest/client/v1/login.py         | 8 +++++++-
 synapse/rest/client/v2_alpha/account.py | 5 +++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 093bc072f..0c9cdff3b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
         if 'medium' in login_submission and 'address' in login_submission:
+            address = login_submission['address']
+            if login_submission['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], login_submission['address']
+                login_submission['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e74e5e012..398e7f5eb 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
             threepid = result[LoginType.EMAIL_IDENTITY]
             if 'medium' not in threepid or 'address' not in threepid:
                 raise SynapseError(500, "Malformed threepid")
+            if threepid['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                threepid['address'] = threepid['address'].lower()
             # if using email, we must know about the email they're authing with!
             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                 threepid['medium'], threepid['address']

From c430111d0e6efb6a0f929cc3e10f1ce4f32d2c18 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 18 Jan 2017 14:55:23 +0000
Subject: [PATCH 42/45] Update LruCache size estimate on clear

---
 synapse/util/caches/lrucache.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 072f9a9d1..cf5fbb679 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -189,6 +189,8 @@ class LruCache(object):
             for cb in node.callbacks:
                 cb()
         cache.clear()
+        if size_callback:
+            cached_cache_len[0] = 0
 
     @synchronized
     def cache_contains(key):
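
Patch 42 is a bookkeeping fix: when an `LruCache` is built with a
`size_callback`, its reported length is a running total
(`cached_cache_len`) that `clear()` previously left at its stale value.
The pattern in miniature (an illustrative stand-in, not synapse's
LruCache, which also tracks recency and per-node callbacks):

    class SizedCache(object):
        def __init__(self, size_callback=None):
            self.data = {}
            self.size_callback = size_callback
            self.size = 0  # running total, like cached_cache_len[0]

        def set(self, key, value):
            if self.size_callback:
                old = self.data.get(key)
                if old is not None:
                    self.size -= self.size_callback(old)
                self.size += self.size_callback(value)
            self.data[key] = value

        def clear(self):
            self.data.clear()
            if self.size_callback:
                # Without this reset the total keeps its old value --
                # the bug the patch fixes.
                self.size = 0

    c = SizedCache(size_callback=len)
    c.set("rows", [1, 2, 3])
    assert c.size == 3
    c.clear()
    assert c.size == 0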
From 1e38be3a7aaea1b6570b27e271855ee380a9129b Mon Sep 17 00:00:00 2001
From: Marvin Steadfast
Date: Thu, 19 Jan 2017 14:08:20 +0100
Subject: [PATCH 43/45] Added username and password for turn server

It makes it possible to use a turn server that needs a username and
password instead of a token.
---
 synapse/config/voip.py         |  4 +++-
 synapse/rest/client/v1/voip.py | 28 ++++++++++++++++++----------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 169980f60..ef9d61adf 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -19,7 +19,9 @@ class VoipConfig(Config):
 
     def read_config(self, config):
         self.turn_uris = config.get("turn_uris", [])
-        self.turn_shared_secret = config["turn_shared_secret"]
+        self.turn_shared_secret = config.get("turn_shared_secret")
+        self.turn_username = config.get("turn_username")
+        self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
 
     def default_config(self, **kwargs):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index c40442f95..03141c623 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -32,19 +32,27 @@ class VoipRestServlet(ClientV1RestServlet):
 
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
+        turnUsername = self.hs.config.turn_username
+        turnPassword = self.hs.config.turn_password
         userLifetime = self.hs.config.turn_user_lifetime
-        if not turnUris or not turnSecret or not userLifetime:
+
+        if turnUris and turnSecret and userLifetime:
+            expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
+            username = "%d:%s" % (expiry, requester.user.to_string())
+
+            mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
+            # We need to use standard padded base64 encoding here
+            # encode_base64 because we need to add the standard padding to get the
+            # same result as the TURN server.
+            password = base64.b64encode(mac.digest())
+
+        elif turnUris and turnUsername and turnPassword and userLifetime:
+            username = turnUsername
+            password = turnPassword
+
+        else:
             defer.returnValue((200, {}))
 
-        expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
-        username = "%d:%s" % (expiry, requester.user.to_string())
-
-        mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
-        # We need to use standard padded base64 encoding here
-        # encode_base64 because we need to add the standard padding to get the
-        # same result as the TURN server.
-        password = base64.b64encode(mac.digest())
-
         defer.returnValue((200, {
             'username': username,
             'password': password,
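
The shared-secret branch that patch 43 keeps as the first case is the
ephemeral-credential scheme from draft-uberti-behave-turn-rest (the "TURN
REST API"): the username embeds an expiry timestamp, and the password is
the padded-base64 HMAC-SHA1 of that username under the secret shared with
the TURN server, so the server can check credentials without any lookup.
Standalone, the computation is roughly as follows (written for Python 3
with explicit byte handling; the patch's Python 2 code passes byte strings
directly):

    import base64
    import hashlib
    import hmac
    import time

    def turn_credentials(shared_secret, user_id, lifetime_secs):
        # The username carries the expiry, so the TURN server can verify
        # the credential from the HMAC alone.
        expiry = int(time.time()) + lifetime_secs
        username = "%d:%s" % (expiry, user_id)
        mac = hmac.new(shared_secret.encode("utf8"),
                       msg=username.encode("utf8"),
                       digestmod=hashlib.sha1)
        # Standard *padded* base64, to match what the TURN server computes.
        password = base64.b64encode(mac.digest()).decode("ascii")
        return username, password

    # e.g. turn_credentials("YOUR_SHARED_SECRET", "@user:example.com", 86400)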
From 86e616568793f4b208137ed61add2d5aba9d6c43 Mon Sep 17 00:00:00 2001
From: Marvin Steadfast
Date: Thu, 19 Jan 2017 14:35:55 +0100
Subject: [PATCH 44/45] Added default config for turn username and password

---
 synapse/config/voip.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index ef9d61adf..eeb693027 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -34,6 +34,11 @@ class VoipConfig(Config):
         # The shared secret used to compute passwords for the TURN server
         turn_shared_secret: "YOUR_SHARED_SECRET"
 
+        # The username and password if the TURN server needs them and
+        # does not use a token
+        #turn_username: "TURNSERVER_USERNAME"
+        #turn_password: "TURNSERVER_PASSWORD"
+
         # How long generated TURN credentials last
         turn_user_lifetime: "1h"
         """

From 97efe99ae964e8f4e866d961282257e6f4293fd8 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 20 Jan 2017 11:45:29 +0000
Subject: [PATCH 45/45] Make worker listener config backwards compat

---
 synapse/config/workers.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 904789d15..b165c67ee 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -29,3 +29,13 @@ class WorkerConfig(Config):
         self.worker_log_file = config.get("worker_log_file")
         self.worker_log_config = config.get("worker_log_config")
         self.worker_replication_url = config.get("worker_replication_url")
+
+        if self.worker_listeners:
+            for listener in self.worker_listeners:
+                bind_address = listener.pop("bind_address", None)
+                bind_addresses = listener.setdefault("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+                elif not bind_addresses:
+                    bind_addresses.append('')
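
The shim in patch 45 is self-contained enough to state as a pure function:
a worker listener may carry the deprecated singular `bind_address`, the
newer `bind_addresses` list, or neither, and the config folds everything
into `bind_addresses`, falling back to the empty string (twisted's "all
interfaces") when neither is present. A sketch under those assumptions
(an illustrative helper, not part of the synapse config classes):

    def normalise_listener(listener):
        bind_address = listener.pop("bind_address", None)
        bind_addresses = listener.setdefault("bind_addresses", [])

        if bind_address:
            bind_addresses.append(bind_address)
        elif not bind_addresses:
            # Neither form supplied: listen on all interfaces, which is
            # what the empty string means to reactor.listenTCP's interface.
            bind_addresses.append('')

        return listener

    assert normalise_listener({"port": 9092, "bind_address": "127.0.0.1"}) == \
        {"port": 9092, "bind_addresses": ["127.0.0.1"]}
    assert normalise_listener({"port": 9092}) == \
        {"port": 9092, "bind_addresses": ['']}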