From 61407986b40e28b590961d364f5618bbe7d44e94 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 16:18:46 +0100 Subject: [PATCH 01/34] Add a entry to current_state_resets table when the current state is reset --- synapse/storage/events.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a4b899549..bd4d503b6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -205,6 +205,15 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = event.internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) + self._simple_delete_txn( txn, table="current_state_events", From 1fbb094c6fbaab33ef8e17802e37057e83718e7e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 17:19:56 +0100 Subject: [PATCH 02/34] Add replication streams for ex outliers and current state resets --- synapse/replication/resource.py | 17 +++++- synapse/storage/events.py | 60 ++++++++++++++++++- .../storage/schema/delta/30/state_stream.sql | 38 ++++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 synapse/storage/schema/delta/30/state_stream.sql diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 096a79a7a..7afa1242d 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,7 +204,11 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - events_rows, backfill_rows = yield self.store.get_all_new_events( + ( + events_rows, backfill_rows, + forward_ex_outliers, backward_ex_outliers, + state_resets + ) = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit @@ -215,6 +219,17 @@ class ReplicationResource(Resource): writer.write_header_and_rows("backfill", backfill_rows, ( "position", "internal", "json", "state_group" )) + writer.write_header_and_rows( + "forward_ex_outliers", forward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "backward_ex_outliers", backward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "state_resets", state_resets, ("position",) + ) @defer.inlineCallbacks def presence(self, writer, current_token): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bd4d503b6..9725a3fed 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -323,6 +323,18 @@ class EventsStore(SQLBaseStore): (metadata_json, event.event_id,) ) + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group or context.new_state_group_id + self._simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + } + ) + sql = ( "UPDATE events SET outlier = ?" " WHERE event_id = ?" @@ -1119,8 +1131,34 @@ class EventsStore(SQLBaseStore): if last_forward_id != current_forward_id: txn.execute(sql, (last_forward_id, current_forward_id, limit)) new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT -event_stream_ordering FROM current_state_resets" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + state_resets = txn.fetchall() + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() else: new_forward_events = [] + state_resets = [] + forward_ex_outliers = [] sql = ( "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" @@ -1136,8 +1174,28 @@ class EventsStore(SQLBaseStore): if last_backfill_id != current_backfill_id: txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() else: new_backfill_events = [] + backward_ex_outliers = [] - return (new_forward_events, new_backfill_events) + return ( + new_forward_events, new_backfill_events, + forward_ex_outliers, backward_ex_outliers, + state_resets, + ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql new file mode 100644 index 000000000..706fe1dcf --- /dev/null +++ b/synapse/storage/schema/delta/30/state_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * The positions in the event stream_ordering when the current_state was + * replaced by the state at the event. + */ + +CREATE TABLE IF NOT EXISTS current_state_resets( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL +); + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); From c27c51484a306caf3946f205a933878c563d3d8a Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 31 Mar 2016 10:12:31 +0100 Subject: [PATCH 03/34] Don't ignore the obey overlay if the rule has an enabled attribute of False Fixes https://github.com/vector-im/vector-web/issues/1244 --- synapse/push/push_rule_evaluator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 51f73a5b7..c3c287762 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -133,8 +133,9 @@ class PushRuleEvaluator: enabled = self.enabled_map.get(r['rule_id'], None) if enabled is not None and not enabled: continue - - if not r.get("enabled", True): + elif enabled is None and not r.get("enabled", True): + # if no override, check enabled on the rule itself + # (may have come from a base rule) continue conditions = r['conditions'] From 2ec54260350b46c937527bd566b713cf3544f1d2 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 10:33:02 +0100 Subject: [PATCH 04/34] Use a namedtuple rather than tuple unpacking --- synapse/replication/resource.py | 16 ++++++---------- synapse/storage/events.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 7afa1242d..69afcb03d 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,31 +204,27 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - ( - events_rows, backfill_rows, - forward_ex_outliers, backward_ex_outliers, - state_resets - ) = yield self.store.get_all_new_events( + res = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows("events", events_rows, ( + writer.write_header_and_rows("events", res.new_forward_events, ( "position", "internal", "json", "state_group" )) - writer.write_header_and_rows("backfill", backfill_rows, ( + writer.write_header_and_rows("backfill", res.new_backfill_events, ( "position", "internal", "json", "state_group" )) writer.write_header_and_rows( - "forward_ex_outliers", forward_ex_outliers, + "forward_ex_outliers", res.forward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "backward_ex_outliers", backward_ex_outliers, + "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "state_resets", state_resets, ("position",) + "state_resets", res.state_resets, ("position",) ) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9725a3fed..b7ad045e4 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json from contextlib import contextmanager - +from collections import namedtuple import logging import math @@ -1193,9 +1193,16 @@ class EventsStore(SQLBaseStore): new_backfill_events = [] backward_ex_outliers = [] - return ( + return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + + +AllNewEventsResult = namedtuple("AllNewEventsResult", [ + "new_forward_events", "new_backfill_events", + "forward_ex_outliers", "backward_ex_outliers", + "state_resets" +]) From 5260db7663ff242e8a0adcbdfeeaf0f6e7ba1e96 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:49:16 +0100 Subject: [PATCH 05/34] Line length --- synapse/handlers/room.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 71f7ab3d2..a230dc37f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -536,7 +536,9 @@ class RoomMemberHandler(BaseHandler): if not is_host_in_room: # perhaps we've been invited - inviter = self.get_inviter(target_user.to_string(), context.current_state) + inviter = self.get_inviter( + target_user.to_string(), context.current_state + ) if not inviter: raise SynapseError(404, "Not a known room") From d35780eda02935165f38f3fcbf8843f40026eeaa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 13:08:45 +0100 Subject: [PATCH 06/34] Split out RoomMemberHandler --- synapse/handlers/__init__.py | 3 +- synapse/handlers/room.py | 605 +----------------------------- synapse/handlers/room_member.py | 646 ++++++++++++++++++++++++++++++++ 3 files changed, 651 insertions(+), 603 deletions(-) create mode 100644 synapse/handlers/room_member.py diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 66d2c0112..f4dbf47c1 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomListHandler, RoomContextHandler, ) +from .room_member import RoomMemberHandler from .message import MessageHandler from .events import EventStreamHandler, EventHandler from .federation import FederationHandler diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a230dc37f..ee99ded21 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,20 +18,16 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, Membership, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset, ) -from synapse.api.errors import AuthError, StoreError, SynapseError, Codes +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - from collections import OrderedDict -from unpaddedbase64 import decode_base64 import logging import math @@ -357,601 +353,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomMemberHandler(BaseHandler): - # TODO(paul): This handler currently contains a messy conflation of - # low-level API that works on UserID objects and so on, and REST-level - # API that takes ID strings and returns pagination chunks. These concerns - # ought to be separated out a lot better. - - def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) - - self.clock = hs.get_clock() - - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") - - @defer.inlineCallbacks - def get_room_members(self, room_id): - users = yield self.store.get_users_in_room(room_id) - - defer.returnValue([UserID.from_string(u) for u in users]) - - @defer.inlineCallbacks - def fetch_room_distributions_into(self, room_id, localusers=None, - remotedomains=None, ignore_user=None): - """Fetch the distribution of a room, adding elements to either - 'localusers' or 'remotedomains', which should be a set() if supplied. - If ignore_user is set, ignore that user. - - This function returns nothing; its result is performed by the - side-effect on the two passed sets. This allows easy accumulation of - member lists of multiple rooms at once if required. - """ - members = yield self.get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if self.hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - @defer.inlineCallbacks - def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - ): - effective_membership_state = action - if action in ["kick", "unban"]: - effective_membership_state = "leave" - - if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( - third_party_signed["sender"], - target.to_string(), - room_id, - third_party_signed, - ) - - msg_handler = self.hs.get_handlers().message_handler - - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" - - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, - ) - - old_state = context.current_state.get((EventTypes.Member, event.state_key)) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was is banned" % (action,), - errcode=Codes.BAD_STATE - ) - - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, - ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, - ) - - @defer.inlineCallbacks - def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, - ): - """ - Change the membership status of a user in a room. - - Args: - requester (Requester): The local user who requested the membership - event. If None, certain checks, like whether this homeserver can - act as the sender, will be skipped. - event (SynapseEvent): The membership event. - context: The context of the event. - is_guest (bool): Whether the sender is a guest. - room_hosts ([str]): Homeservers which are likely to already be in - the room, and could be danced with in order to join this - homeserver for the first time. - ratelimit (bool): Whether to rate limit this request. - Raises: - SynapseError if there was a problem changing the membership. - """ - remote_room_hosts = remote_room_hosts or [] - - target_user = UserID.from_string(event.state_key) - room_id = event.room_id - - if requester is not None: - sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) - assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - else: - requester = Requester(target_user, None, False) - - message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) - if prev_event is not None: - return - - action = "send" - - if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter( - target_user.to_string(), context.current_state - ) - if not inviter: - raise SynapseError(404, "Not a known room") - - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) - - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), - None - ) - - if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - yield user_joined_room(self.distributor, target_user, room_id) - elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) - - def _can_guest_join(self, current_state): - """ - Returns whether a guest can join a room based on its current state. - """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - - def _should_do_dance(self, context, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(context.current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): - """ - Get the room ID associated with a room alias. - - Args: - room_alias (RoomAlias): The alias to look up. - Returns: - A tuple of: - The room ID as a RoomID object. - Hosts likely to be participating in the room ([str]). - Raises: - SynapseError if room alias could not be found. - """ - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if not mapping: - raise SynapseError(404, "No such room alias") - - room_id = mapping["room_id"] - servers = mapping["servers"] - - defer.returnValue((RoomID.from_string(room_id), servers)) - - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None - - @defer.inlineCallbacks - def get_joined_rooms_for_user(self, user): - """Returns a list of roomids that the user has any of the given - membership states in.""" - - rooms = yield self.store.get_rooms_for_user( - user.to_string(), - ) - - # For some reason the list of events contains duplicates - # TODO(paul): work out why because I really don't think it should - room_ids = set(r.room_id for r in rooms) - - defer.returnValue(room_ids) - - @defer.inlineCallbacks - def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id - ): - invitee = yield self._lookup_3pid( - id_server, medium, address - ) - - if invitee: - handler = self.hs.get_handlers().room_member_handler - yield handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, - ) - else: - yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id - ) - - @defer.inlineCallbacks - def _lookup_3pid(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. - """ - try: - data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) - defer.returnValue(data["mxid"]) - - except IOError as e: - logger.warn("Error from identity server lookup: %s" % (e,)) - defer.returnValue(None) - - @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), - ) - if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) - ) - return - - @defer.inlineCallbacks - def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id - ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) - - inviter_display_name = "" - inviter_avatar_url = "" - member_event = room_state.get((EventTypes.Member, user.to_string())) - if member_event: - inviter_display_name = member_event.content.get("displayname", "") - inviter_avatar_url = member_event.content.get("avatar_url", "") - - canonical_room_alias = "" - canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) - if canonical_alias_event: - canonical_room_alias = canonical_alias_event.content.get("alias", "") - - room_name = "" - room_name_event = room_state.get((EventTypes.Name, "")) - if room_name_event: - room_name = room_name_event.content.get("name", "") - - room_join_rules = "" - join_rules_event = room_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - room_join_rules = join_rules_event.content.get("join_rule", "") - - room_avatar_url = "" - room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if room_avatar_event: - room_avatar_url = room_avatar_event.content.get("url", "") - - token, public_keys, fallback_public_key, display_name = ( - yield self._ask_id_server_for_third_party_invite( - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url - ) - ) - - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.ThirdPartyInvite, - "content": { - "display_name": display_name, - "public_keys": public_keys, - - # For backwards compatibility: - "key_validity_url": fallback_public_key["key_validity_url"], - "public_key": fallback_public_key["public_key"], - }, - "room_id": room_id, - "sender": user.to_string(), - "state_key": token, - }, - txn_id=txn_id, - ) - - @defer.inlineCallbacks - def _ask_id_server_for_third_party_invite( - self, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url - ): - """ - Asks an identity server for a third party invite. - - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. - """ - - is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, - ) - - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( - medium=medium, - address=address, - inviter_user_id=inviter_user_id, - ) - - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), - }) - - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( - is_url, - invite_config - ) - # TODO: Check for success - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - defer.returnValue((token, public_keys, fallback_public_key, display_name)) - - @defer.inlineCallbacks - def forget(self, user, room_id): - user_id = user.to_string() - - member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) - membership = member.membership if member else None - - if membership is not None and membership != Membership.LEAVE: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) - - if membership: - yield self.store.forget(user_id, room_id) - - class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py new file mode 100644 index 000000000..5fdbd3adc --- /dev/null +++ b/synapse/handlers/room_member.py @@ -0,0 +1,646 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.types import UserID, RoomID, Requester +from synapse.api.constants import ( + EventTypes, Membership, +) +from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.util.logcontext import preserve_context_over_fn + +from signedjson.sign import verify_signed_json +from signedjson.key import decode_verify_key_bytes + +from unpaddedbase64 import decode_base64 + +import logging + +logger = logging.getLogger(__name__) + +id_server_scheme = "https://" + + +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + +class RoomMemberHandler(BaseHandler): + # TODO(paul): This handler currently contains a messy conflation of + # low-level API that works on UserID objects and so on, and REST-level + # API that takes ID strings and returns pagination chunks. These concerns + # ought to be separated out a lot better. + + def __init__(self, hs): + super(RoomMemberHandler, self).__init__(hs) + + self.clock = hs.get_clock() + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def get_room_members(self, room_id): + users = yield self.store.get_users_in_room(room_id) + + defer.returnValue([UserID.from_string(u) for u in users]) + + @defer.inlineCallbacks + def fetch_room_distributions_into(self, room_id, localusers=None, + remotedomains=None, ignore_user=None): + """Fetch the distribution of a room, adding elements to either + 'localusers' or 'remotedomains', which should be a set() if supplied. + If ignore_user is set, ignore that user. + + This function returns nothing; its result is performed by the + side-effect on the two passed sets. This allows easy accumulation of + member lists of multiple rooms at once if required. + """ + members = yield self.get_room_members(room_id) + for member in members: + if ignore_user is not None and member == ignore_user: + continue + + if self.hs.is_mine(member): + if localusers is not None: + localusers.add(member) + else: + if remotedomains is not None: + remotedomains.add(member.domain) + + @defer.inlineCallbacks + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + ): + effective_membership_state = action + if action in ["kick", "unban"]: + effective_membership_state = "leave" + + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + + msg_handler = self.hs.get_handlers().message_handler + + content = {"membership": effective_membership_state} + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": effective_membership_state, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + + old_state = context.current_state.get((EventTypes.Member, event.state_key)) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was is banned" % (action,), + errcode=Codes.BAD_STATE + ) + + member_handler = self.hs.get_handlers().room_member_handler + yield member_handler.send_membership_event( + requester, + event, + context, + ratelimit=ratelimit, + remote_room_hosts=remote_room_hosts, + ) + + @defer.inlineCallbacks + def send_membership_event( + self, + requester, + event, + context, + remote_room_hosts=None, + ratelimit=True, + ): + """ + Change the membership status of a user in a room. + + Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + Raises: + SynapseError if there was a problem changing the membership. + """ + remote_room_hosts = remote_room_hosts or [] + + target_user = UserID.from_string(event.state_key) + room_id = event.room_id + + if requester is not None: + sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = Requester(target_user, None, False) + + message_handler = self.hs.get_handlers().message_handler + prev_event = message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return + + action = "send" + + if event.membership == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(context.current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + do_remote_join_dance, remote_room_hosts = self._should_do_dance( + context, + (self.get_inviter(event.state_key, context.current_state)), + remote_room_hosts, + ) + if do_remote_join_dance: + action = "remote_join" + elif event.membership == Membership.LEAVE: + is_host_in_room = self.is_host_in_room(context.current_state) + + if not is_host_in_room: + # perhaps we've been invited + inviter = self.get_inviter( + target_user.to_string(), context.current_state + ) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + action = "remote_reject" + + federation_handler = self.hs.get_handlers().federation_handler + + if action == "remote_join": + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield federation_handler.do_invite_join( + remote_room_hosts, + event.room_id, + event.user_id, + event.content, + ) + elif action == "remote_reject": + yield federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + event.user_id + ) + else: + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) + + prev_member_event = context.current_state.get( + (EventTypes.Member, target_user.to_string()), + None + ) + + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + def _can_guest_join(self, current_state): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access = current_state.get((EventTypes.GuestAccess, ""), None) + return ( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) + + def _should_do_dance(self, context, inviter, room_hosts=None): + # TODO: Shouldn't this be remote_room_host? + room_hosts = room_hosts or [] + + is_host_in_room = self.is_host_in_room(context.current_state) + if is_host_in_room: + return False, room_hosts + + if inviter and not self.hs.is_mine(inviter): + room_hosts.append(inviter.domain) + + return True, room_hosts + + @defer.inlineCallbacks + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. + + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if not mapping: + raise SynapseError(404, "No such room alias") + + room_id = mapping["room_id"] + servers = mapping["servers"] + + defer.returnValue((RoomID.from_string(room_id), servers)) + + def get_inviter(self, user_id, current_state): + prev_state = current_state.get((EventTypes.Member, user_id)) + if prev_state and prev_state.membership == Membership.INVITE: + return UserID.from_string(prev_state.user_id) + return None + + @defer.inlineCallbacks + def get_joined_rooms_for_user(self, user): + """Returns a list of roomids that the user has any of the given + membership states in.""" + + rooms = yield self.store.get_rooms_for_user( + user.to_string(), + ) + + # For some reason the list of events contains duplicates + # TODO(paul): work out why because I really don't think it should + room_ids = set(r.room_id for r in rooms) + + defer.returnValue(room_ids) + + @defer.inlineCallbacks + def do_3pid_invite( + self, + room_id, + inviter, + medium, + address, + id_server, + requester, + txn_id + ): + invitee = yield self._lookup_3pid( + id_server, medium, address + ) + + if invitee: + handler = self.hs.get_handlers().room_member_handler + yield handler.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + txn_id=txn_id, + ) + else: + yield self._make_and_store_3pid_invite( + requester, + id_server, + medium, + address, + room_id, + inviter, + txn_id=txn_id + ) + + @defer.inlineCallbacks + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + (str) the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), + { + "medium": medium, + "address": address, + } + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + self.verify_any_signature(data, id_server) + defer.returnValue(data["mxid"]) + + except IOError as e: + logger.warn("Error from identity server lookup: %s" % (e,)) + defer.returnValue(None) + + @defer.inlineCallbacks + def verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + key_data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/pubkey/%s" % + (id_server_scheme, server_hostname, key_name,), + ) + if "public_key" not in key_data: + raise AuthError(401, "No public key named %s from %s" % + (key_name, server_hostname,)) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + ) + return + + @defer.inlineCallbacks + def _make_and_store_3pid_invite( + self, + requester, + id_server, + medium, + address, + room_id, + user, + txn_id + ): + room_state = yield self.hs.get_state_handler().get_current_state(room_id) + + inviter_display_name = "" + inviter_avatar_url = "" + member_event = room_state.get((EventTypes.Member, user.to_string())) + if member_event: + inviter_display_name = member_event.content.get("displayname", "") + inviter_avatar_url = member_event.content.get("avatar_url", "") + + canonical_room_alias = "" + canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) + if canonical_alias_event: + canonical_room_alias = canonical_alias_event.content.get("alias", "") + + room_name = "" + room_name_event = room_state.get((EventTypes.Name, "")) + if room_name_event: + room_name = room_name_event.content.get("name", "") + + room_join_rules = "" + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + room_join_rules = join_rules_event.content.get("join_rule", "") + + room_avatar_url = "" + room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if room_avatar_event: + room_avatar_url = room_avatar_event.content.get("url", "") + + token, public_keys, fallback_public_key, display_name = ( + yield self._ask_id_server_for_third_party_invite( + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url + ) + ) + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.ThirdPartyInvite, + "content": { + "display_name": display_name, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], + }, + "room_id": room_id, + "sender": user.to_string(), + "state_key": token, + }, + txn_id=txn_id, + ) + + @defer.inlineCallbacks + def _ask_id_server_for_third_party_invite( + self, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url + ): + """ + Asks an identity server for a third party invite. + + :param id_server (str): hostname + optional port for the identity server. + :param medium (str): The literal string "email". + :param address (str): The third party address being invited. + :param room_id (str): The ID of the room to which the user is invited. + :param inviter_user_id (str): The user ID of the inviter. + :param room_alias (str): An alias for the room, for cosmetic + notifications. + :param room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + :param room_join_rules (str): The join rules of the email + (e.g. "public"). + :param room_name (str): The m.room.name of the room. + :param inviter_display_name (str): The current display name of the + inviter. + :param inviter_avatar_url (str): The URL of the inviter's avatar. + + :return: A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( + id_server_scheme, id_server, + ) + + invite_config = { + "medium": medium, + "address": address, + "room_id": room_id, + "room_alias": room_alias, + "room_avatar_url": room_avatar_url, + "room_join_rules": room_join_rules, + "room_name": room_name, + "sender": inviter_user_id, + "sender_display_name": inviter_display_name, + "sender_avatar_url": inviter_avatar_url, + } + + if self.hs.config.invite_3pid_guest: + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + + guest_user_info = yield self.hs.get_auth().get_user_by_access_token( + guest_access_token + ) + + invite_config.update({ + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_info["user"].to_string(), + }) + + data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + is_url, + invite_config + ) + # TODO: Check for success + token = data["token"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) + display_name = data["display_name"] + defer.returnValue((token, public_keys, fallback_public_key, display_name)) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) From 76503f95ed2b618a3ff97c6b04868d8e440ef9d4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:00:42 +0100 Subject: [PATCH 07/34] Remove the is_new_state argument to persist event. Move the checks for whether an event is new state inside persist event itself. This was harder than expected because there wasn't enough information passed to persist event to correctly handle invites from remote servers for new rooms. --- synapse/events/__init__.py | 3 ++ synapse/handlers/federation.py | 20 ++------ synapse/storage/events.py | 88 +++++++++++++++++++--------------- 3 files changed, 56 insertions(+), 55 deletions(-) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 925a83c64..13154b172 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -33,6 +33,9 @@ class _EventInternalMetadata(object): def is_outlier(self): return getattr(self, "outlier", False) + def is_invite_from_remote(self): + return getattr(self, "invite_from_remote", False) + def _event_dict_property(key): def getter(self): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 267fedf11..4a35344d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -102,8 +102,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, state=None, - auth_chain=None): + def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -174,11 +173,7 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events(origin, event_infos) try: context, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -761,6 +756,7 @@ class FederationHandler(BaseHandler): event = pdu event.internal_metadata.outlier = True + event.internal_metadata.invite_from_remote = True event.signatures.update( compute_event_signature( @@ -1069,9 +1065,6 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None): - - outlier = event.internal_metadata.is_outlier() - context = yield self._prep_event( origin, event, state=state, @@ -1087,14 +1080,12 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - is_new_state=not outlier, ) defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False, - outliers=False): + def _handle_new_events(self, origin, event_infos, backfilled=False): contexts = yield defer.gatherResults( [ self._prep_event( @@ -1113,7 +1104,6 @@ class FederationHandler(BaseHandler): for ev_info, context in itertools.izip(event_infos, contexts) ], backfilled=backfilled, - is_new_state=(not outliers and not backfilled), ) @defer.inlineCallbacks @@ -1176,7 +1166,6 @@ class FederationHandler(BaseHandler): (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ], - is_new_state=False, ) new_event_context = yield self.state_handler.compute_event_context( @@ -1185,7 +1174,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - is_new_state=True, current_state=state, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index dc3e994de..b5f9b3b90 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -61,8 +61,7 @@ class EventsStore(SQLBaseStore): ) @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False, - is_new_state=True): + def persist_events(self, events_and_contexts, backfilled=False): if not events_and_contexts: return @@ -110,13 +109,11 @@ class EventsStore(SQLBaseStore): self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, - is_new_state=is_new_state, ) @defer.inlineCallbacks @log_function - def persist_event(self, event, context, - is_new_state=True, current_state=None): + def persist_event(self, event, context, current_state=None): try: with self._stream_id_gen.get_next() as stream_ordering: @@ -128,7 +125,6 @@ class EventsStore(SQLBaseStore): self._persist_event_txn, event=event, context=context, - is_new_state=is_new_state, current_state=current_state, ) except _RollbackButIsFineException: @@ -194,8 +190,7 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @log_function - def _persist_event_txn(self, txn, event, context, - is_new_state, current_state): + def _persist_event_txn(self, txn, event, context, current_state): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -236,12 +231,10 @@ class EventsStore(SQLBaseStore): txn, [(event, context)], backfilled=False, - is_new_state=is_new_state, ) @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - is_new_state): + def _persist_events_txn(self, txn, events_and_contexts, backfilled): depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -452,10 +445,9 @@ class EventsStore(SQLBaseStore): txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = filter( - lambda i: i[0].is_state(), - events_and_contexts, - ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] state_values = [] for event, context in state_events_and_contexts: @@ -493,32 +485,50 @@ class EventsStore(SQLBaseStore): ], ) - if is_new_state: - for event, _ in state_events_and_contexts: - if not context.rejected: - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) + for event, _ in state_events_and_contexts: + if backfilled: + # Backfilled events come before the current state so shouldn't + # clobber it. + continue - if event.type in [EventTypes.Name, EventTypes.Aliases]: - txn.call_after( - self.get_room_name_and_aliases.invalidate, - (event.room_id,) - ) + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + # Outlier events generally shouldn't clobber the current state. + # However invites from remote severs for rooms we aren't in + # are a bit special: they don't come with any associated + # state so are technically an outlier, however all the + # client-facing code assumes that they are in the current + # state table so we insert the event anyway. + continue - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) + if context.rejected: + # If the event failed it's auth checks then it shouldn't + # clobbler the current state. + continue + + txn.call_after( + self._get_current_state_for_key.invalidate, + (event.room_id, event.type, event.state_key,) + ) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + txn.call_after( + self.get_room_name_and_aliases.invalidate, + (event.room_id,) + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } + ) return From 5d069291696c62c94d795506dd5a22a3dbd019e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:09:09 +0100 Subject: [PATCH 08/34] Move the check for backfilled outside the for loop --- synapse/storage/events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b5f9b3b90..83279d65f 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -485,12 +485,12 @@ class EventsStore(SQLBaseStore): ], ) - for event, _ in state_events_and_contexts: - if backfilled: - # Backfilled events come before the current state so shouldn't - # clobber it. - continue + if backfilled: + # Backfilled events come before the current state so we don't need + # to update the current state table + return + for event, _ in state_events_and_contexts: if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): # Outlier events generally shouldn't clobber the current state. From dc4c1579d48b7db8264a935be3e955b779c78ab6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:32:24 +0100 Subject: [PATCH 09/34] Remove outlier parameter from compute_event_context Use event.internal_metadata.is_outlier instead. --- synapse/handlers/_base.py | 3 +-- synapse/handlers/federation.py | 11 ++++------- synapse/state.py | 4 ++-- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb..d407eaeee 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -261,8 +261,7 @@ class BaseHandler(object): context = yield state_handler.compute_event_context( builder, - old_state=(prev_member_event,), - outlier=True + old_state=(prev_member_event,) ) if builder.is_state(): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d3..4049c01d2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1118,11 +1118,9 @@ class FederationHandler(BaseHandler): """ events_to_context = {} for e in itertools.chain(auth_events, state): - ctx = yield self.state_handler.compute_event_context( - e, outlier=True, - ) - events_to_context[e.event_id] = ctx e.internal_metadata.outlier = True + ctx = yield self.state_handler.compute_event_context(e) + events_to_context[e.event_id] = ctx event_map = { e.event_id: e @@ -1169,7 +1167,7 @@ class FederationHandler(BaseHandler): ) new_event_context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=False, + event, old_state=state ) event_stream_id, max_stream_id = yield self.store.persist_event( @@ -1181,10 +1179,9 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): - outlier = event.internal_metadata.is_outlier() context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=outlier, + event, old_state=state, ) if not auth_events: diff --git a/synapse/state.py b/synapse/state.py index 41d32e664..4672ada1b 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -100,7 +100,7 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None, outlier=False): + def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The `current state` here is defined to be the state of the event graph just before the event - i.e. it never includes `event` @@ -115,7 +115,7 @@ class StateHandler(object): """ context = EventContext() - if outlier: + if event.internal_metadata.is_outlier(): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. From 7753fc65702dc2b98885f462fe0e6ac9f8e27b06 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 10:34:51 +0100 Subject: [PATCH 10/34] Fix the invalidation of the names and aliases cache --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65f..7468e6e00 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -198,7 +198,7 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) - txn.call_after(self.get_room_name_and_aliases, event.room_id) + txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state From 35bb465b8698cb3ed9d034563d3b9ee03579e775 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 13:10:07 +0100 Subject: [PATCH 11/34] Filter rooms list before chunking --- synapse/handlers/sync.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e..06098f899 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -252,6 +252,18 @@ class SyncHandler(BaseHandler): archived = [] deferreds = [] + user_id = sync_config.user.to_string() + + def _should_include_room(event): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + return False + return True + + room_list = filter(_should_include_room, room_list) + room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] for room_list_chunk in room_list_chunks: for event in room_list_chunk: @@ -276,12 +288,6 @@ class SyncHandler(BaseHandler): invite=invite, )) elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if sync_config.user.to_string() == event.sender: - continue - leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) From e36bfbab38def70e0fcc1bafcecb6e666dbbc1ad Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:29:05 +0100 Subject: [PATCH 12/34] Use a stream id generator for backfilled ids --- synapse/storage/__init__.py | 20 +++------ synapse/storage/account_data.py | 4 +- synapse/storage/events.py | 19 +++------ synapse/storage/presence.py | 6 ++- synapse/storage/push_rule.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 6 +-- synapse/storage/state.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 6 +-- synapse/storage/util/id_generators.py | 61 ++++++++++++++++++--------- 11 files changed, 69 insertions(+), 61 deletions(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index aaad38039..f87e907cd 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore, self.hs = hs self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering" ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", direction=-1 + ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) @@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) @@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe21..7a7fbf1e5 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65f..4ab23c159 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -24,7 +24,6 @@ from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from contextlib import contextmanager from collections import namedtuple import logging @@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) @@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e31..59b4ef5ce 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5..d2bf7f2ae 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac577..d1669c778 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers_txn(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b9d848ea..4befebc8e 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) @cached(num_args=2) @@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f26..864483065 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -458,4 +458,4 @@ class StateStore(SQLBaseStore): ) def get_state_stream_token(self): - return self._state_groups_id_gen.get_max_token() + return self._state_groups_id_gen.get_current_token() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index cf84938be..76bcd9cd0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() + token = yield self._stream_id_gen.get_current_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b3..9da23f34c 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d5..03f2aa6a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, direction=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if direction == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else direction + return (max if direction == 1 else min)(current_id, direction) class StreamIdGenerator(object): @@ -45,17 +49,30 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + :param connection db_conn: A database connection to use to fetch the + initial value of the generator from. + :param str table: A database table to read the initial value of the id + generator from. + :param str column: The column of the database table to read the initial + value from the id generator from. + :param list extra_tables: List of pairs of database tables and columns to + use to source the initial value of the generator from. The value with + the largest magnitude is used. + :param int direction: which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], direction=1): self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._direction = direction + self._current = _load_current_id(db_conn, table, column, direction) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if direction > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, direction) ) self._unfinished_ids = deque() @@ -66,8 +83,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._direction + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +105,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._direction, + self._current + self._direction * (n + 1), + self._direction + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +126,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._direction - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +146,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +158,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +172,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +181,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) From a2866e2e6a8fa60a538a98f62e1733ab062020aa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:50:54 +0100 Subject: [PATCH 13/34] Rename direction to step, apply checks consistently --- synapse/storage/__init__.py | 2 +- synapse/storage/util/id_generators.py | 30 +++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f87e907cd..57863bba4 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering" ) self._backfill_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", direction=-1 + db_conn, "events", "stream_ordering", step=-1 ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 03f2aa6a5..310b7dc6e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -29,16 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_current_id(db_conn, table, column, direction=1): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - if direction == 1: + if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - current_id = int(val) if val else direction - return (max if direction == 1 else min)(current_id, direction) + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -58,21 +58,21 @@ class StreamIdGenerator(object): :param list extra_tables: List of pairs of database tables and columns to use to source the initial value of the generator from. The value with the largest magnitude is used. - :param int direction: which direction the stream ids grow in. +1 to grow + :param int step: which direction the stream ids grow in. +1 to grow upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[], direction=1): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): self._lock = threading.Lock() - self._direction = direction - self._current = _load_current_id(db_conn, table, column, direction) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current = (max if direction > 0 else min)( + self._current = (max if step > 0 else min)( self._current, - _load_current_id(db_conn, table, column, direction) + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -83,7 +83,7 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current += self._direction + self._current += self._step next_id = self._current self._unfinished_ids.append(next_id) @@ -106,9 +106,9 @@ class StreamIdGenerator(object): """ with self._lock: next_ids = range( - self._current + self._direction, - self._current + self._direction * (n + 1), - self._direction + self._current + self._step, + self._current + self._step * (n + 1), + self._step ) self._current += n @@ -132,7 +132,7 @@ class StreamIdGenerator(object): """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._direction + return self._unfinished_ids[0] - self._step return self._current From 8d73cd502bd8ee6903c81f20f79fe5e1509692e3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 14:06:00 +0100 Subject: [PATCH 14/34] Add concurrently_execute function --- synapse/handlers/message.py | 10 +--- synapse/handlers/room.py | 17 +++---- synapse/handlers/sync.py | 96 ++++++++++++++++--------------------- synapse/util/async.py | 32 ++++++++++++- 4 files changed, 81 insertions(+), 74 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5c50c611b..0bb111d04 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -556,14 +557,7 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") - # Only do N rooms at once - n = 5 - d_list = [handle_room(e) for e in room_list] - for i in range(0, len(d_list), n): - yield defer.gatherResults( - d_list[i:i + n], - consumeErrors=True - ).addErrback(unwrapFirstError) + yield concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee99ded21..3e1d9282d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -23,7 +23,8 @@ from synapse.api.constants import ( EventTypes, JoinRules, RoomCreationPreset, ) from synapse.api.errors import AuthError, StoreError, SynapseError -from synapse.util import stringutils, unwrapFirstError +from synapse.util import stringutils +from synapse.util.async import concurrently_execute from synapse.util.logcontext import preserve_context_over_fn from synapse.util.caches.response_cache import ResponseCache @@ -368,6 +369,8 @@ class RoomListHandler(BaseHandler): def _get_public_room_list(self): room_ids = yield self.store.get_public_room_ids() + results = [] + @defer.inlineCallbacks def handle_room(room_id): aliases = yield self.store.get_aliases_for_room(room_id) @@ -428,18 +431,12 @@ class RoomListHandler(BaseHandler): joined_users = yield self.store.get_users_in_room(room_id) result["num_joined_members"] = len(joined_users) - defer.returnValue(result) + results.append(result) - result = [] - for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)): - chunk_result = yield defer.gatherResults([ - handle_room(room_id) - for room_id in chunk - ], consumeErrors=True).addErrback(unwrapFirstError) - result.extend(v for v in chunk_result if v) + yield concurrently_execute(handle_room, room_ids, 10) # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": result}) + defer.returnValue({"start": "START", "end": "END", "chunk": results}) class RoomContextHandler(BaseHandler): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 06098f899..e38fe1ef9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,8 +17,8 @@ from ._base import BaseHandler from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes -from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.async import concurrently_execute +from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -250,64 +250,50 @@ class SyncHandler(BaseHandler): joined = [] invited = [] archived = [] - deferreds = [] user_id = sync_config.user.to_string() - def _should_include_room(event): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if user_id == event.sender: - return False - return True + @defer.inlineCallbacks + def _generate_room_entry(event): + if event.membership == Membership.JOIN: + room_result = yield self.full_state_sync_for_joined_room( + room_id=event.room_id, + sync_config=sync_config, + now_token=now_token, + timeline_since_token=timeline_since_token, + ephemeral_by_room=ephemeral_by_room, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) + joined.append(room_result) + elif event.membership == Membership.INVITE: + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + return - room_list = filter(_should_include_room, room_list) + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_result = yield self.full_state_sync_for_archived_room( + sync_config=sync_config, + room_id=event.room_id, + leave_event_id=event.event_id, + leave_token=leave_token, + timeline_since_token=timeline_since_token, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) + archived.append(room_result) - room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] - for room_list_chunk in room_list_chunks: - for event in room_list_chunk: - if event.membership == Membership.JOIN: - room_sync_deferred = preserve_fn( - self.full_state_sync_for_joined_room - )( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(joined.append) - deferreds.append(room_sync_deferred) - elif event.membership == Membership.INVITE: - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_sync_deferred = preserve_fn( - self.full_state_sync_for_archived_room - )( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(archived.append) - deferreds.append(room_sync_deferred) - - yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) + yield concurrently_execute(_generate_room_entry, room_list, 10) account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) diff --git a/synapse/util/async.py b/synapse/util/async.py index 640fae389..a75e1c71f 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,7 +16,8 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import PreserveLoggingContext, preserve_fn +from synapse.util import unwrapFirstError @defer.inlineCallbacks @@ -107,3 +108,32 @@ class ObservableDeferred(object): return "" % ( id(self), self._result, self._deferred, ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(it.next()) + except StopIteration: + pass + + return defer.gatherResults([ + preserve_fn(_concurrently_execute_inner)() + for _ in xrange(limit) + ], consumeErrors=True).addErrback(unwrapFirstError) From 3f4eb4c92402d80d0f41501bf71a60a1b94f2756 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 14:15:27 +0100 Subject: [PATCH 15/34] Comment --- synapse/util/async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/util/async.py b/synapse/util/async.py index a75e1c71f..cd4d90f3c 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -121,7 +121,7 @@ def concurrently_execute(func, args, limit): limit (int): Maximum number of conccurent executions. Returns: - deferred + deferred: Resolved when all function invocations have finished. """ it = iter(args) From 35b5c4ba1b1892fde18f531c96d71aa58de649e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:07:01 +0100 Subject: [PATCH 16/34] use google style doc strings --- synapse/storage/util/id_generators.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 310b7dc6e..58ea54cf6 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -49,17 +49,18 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. - :param connection db_conn: A database connection to use to fetch the - initial value of the generator from. - :param str table: A database table to read the initial value of the id - generator from. - :param str column: The column of the database table to read the initial - value from the id generator from. - :param list extra_tables: List of pairs of database tables and columns to - use to source the initial value of the generator from. The value with - the largest magnitude is used. - :param int step: which direction the stream ids grow in. +1 to grow - upwards, -1 to grow downwards. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: From 9bc5b4c663ed4bb35ef74a820c108765c7ca0f67 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:08:20 +0100 Subject: [PATCH 17/34] Assert that the step != 0 --- synapse/storage/util/id_generators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 58ea54cf6..f69f1cdad 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -67,6 +67,7 @@ class StreamIdGenerator(object): # ... persist event ... """ def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() self._step = step self._current = _load_current_id(db_conn, table, column, step) From 2a37467fa1358eb41513893efe44cbd294dca36c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 16:08:59 +0100 Subject: [PATCH 18/34] Use google style doc strings. pycharm supports them so there is no need to use the other format. Might as well convert the existing strings to reduce the risk of people accidentally cargo culting the wrong doc string format. --- setup.cfg | 3 + synapse/handlers/_base.py | 27 +++++---- synapse/handlers/auth.py | 26 ++++++--- synapse/handlers/federation.py | 23 ++++---- synapse/handlers/room_member.py | 46 +++++++-------- synapse/handlers/sync.py | 49 ++++++++++------ synapse/http/servlet.py | 81 +++++++++++++++++---------- synapse/notifier.py | 15 ++--- synapse/push/baserules.py | 8 ++- synapse/rest/client/v2_alpha/sync.py | 73 +++++++++++++----------- synapse/state.py | 19 ++++--- synapse/storage/event_push_actions.py | 5 +- synapse/storage/registration.py | 15 +++-- synapse/storage/state.py | 13 +++-- 14 files changed, 238 insertions(+), 165 deletions(-) diff --git a/setup.cfg b/setup.cfg index f8cc13c84..5ebce1c56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,6 @@ ignore = [flake8] max-line-length = 90 ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. + +[pep8] +max-line-length = 90 diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb..5601ecea6 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -41,8 +41,9 @@ class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.events.StateStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -65,11 +66,12 @@ class BaseHandler(object): """ Returns dict of user_id -> list of events that user is allowed to see. - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events + Args: + user_tuples (str, bool): (user id, is_peeking) for each user to be + checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the + given events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -165,13 +167,16 @@ class BaseHandler(object): """ Check which events a user is allowed to see - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: + Args: + user_id(str): user id to be checked + events([synapse.events.EventBase]): list of events to be checked + is_peeking(bool): should be True if: * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events - :rtype [synapse.events.EventBase] + + Returns: + [synapse.events.EventBase] """ types = ( (EventTypes.RoomHistoryVisibility, ""), diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b42..d5d6faa85 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -163,9 +163,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +183,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,9 +196,11 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d3..092802b97 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1706,13 +1706,15 @@ class FederationHandler(BaseHandler): def _check_signature(self, event, auth_events): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + auth_events (dict<(event type, state_key), event>): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] @@ -1754,12 +1756,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adc..01f833c37 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -411,7 +411,7 @@ class RoomMemberHandler(BaseHandler): address (str): The third party identifier (e.g. "foo@example.com"). Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. + str: the matrix ID of the 3pid, or None if it is not recognized. """ try: data = yield self.hs.get_simple_http_client().get_json( @@ -545,29 +545,29 @@ class RoomMemberHandler(BaseHandler): """ Asks an identity server for a third party invite. - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e..20a062657 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -671,7 +671,8 @@ class SyncHandler(BaseHandler): def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None, recents=None, newly_joined_room=False): """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): filtering_factor = 2 @@ -838,8 +839,11 @@ class SyncHandler(BaseHandler): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ state = yield self.store.get_state_for_event(event.event_id) if event.is_state(): @@ -850,9 +854,13 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -873,15 +881,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -953,11 +964,13 @@ class SyncHandler(BaseHandler): Check if the user has just joined the given room (so should be given the full state) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync + Args: + sync_config(synapse.handlers.sync.SyncConfig): + state_delta(dict[(str,str), synapse.events.FrozenEvent]): the + difference in state since the last sync - :returns A deferred Tuple (state_delta, limited) + Returns: + A deferred Tuple (state_delta, limited) """ join_event = state_delta.get(( EventTypes.Member, sync_config.user.to_string()), None) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1c8bd8666..e41afeab8 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -26,14 +26,19 @@ logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): """Parse an integer parameter from the request string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: An int value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (int|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + int|None: An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ if name in request.args: @@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False): """Parse a boolean parameter from the request query string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: A bool value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (bool|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + bool|None: A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ @@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): """Parse a string parameter from the request query string. - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :param allowed_values (list): List of allowed values for the string, - or None if any value is allowed, defaults to None - :return: A string value or the default. - :raises + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (str|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + allowed_values (list[str]): List of allowed values for the string, + or None if any value is allowed, defaults to None + + Returns: + str|None: A string value or the default. + + Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. @@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False, def parse_json_value_from_request(request): """Parse a JSON value from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :returns: The JSON value. - :raises + Args: + request: the twisted HTTP request. + + Returns: + The JSON value. + + Raises: SynapseError if the request body couldn't be decoded as JSON. """ try: @@ -143,8 +162,10 @@ def parse_json_value_from_request(request): def parse_json_object_from_request(request): """Parse a JSON object from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :raises + Args: + request: the twisted HTTP request. + + Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index f00cd8c58..6af7a8f42 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -503,13 +503,14 @@ class Notifier(object): def wait_for_replication(self, callback, timeout): """Wait for an event to happen. - :param callback: - Gets called whenever an event happens. If this returns a truthy - value then ``wait_for_replication`` returns, otherwise it waits - for another event. - :param int timeout: - How many milliseconds to wait for callback return a truthy value. - :returns: + Args: + callback: Gets called whenever an event happens. If this returns a + truthy value then ``wait_for_replication`` returns, otherwise + it waits for another event. + timeout: How many milliseconds to wait for callback return a truthy + value. + + Returns: A deferred that resolves with the value returned by the callback. """ listener = _NotificationListener(None) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 792af70eb..6add94bee 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,9 +19,11 @@ import copy def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules - :param list rawrules: The rules the user has modified or set. - :returns: A new list with the rules set by the user combined with the - defaults. + Args: + rawrules(list): The rules the user has modified or set. + + Returns: + A new list with the rules set by the user combined with the defaults. """ ruleslist = [] diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c5785d707..60d3dc403 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -199,15 +199,17 @@ class SyncRestServlet(RestServlet): """ Encode the joined rooms in a sync result - :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync - results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs - :return: the joined rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format """ joined = {} for room in rooms: @@ -221,15 +223,17 @@ class SyncRestServlet(RestServlet): """ Encode the invited rooms in a sync result - :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format """ invited = {} for room in rooms: @@ -251,15 +255,17 @@ class SyncRestServlet(RestServlet): """ Encode the archived rooms in a sync result - :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format """ joined = {} for room in rooms: @@ -272,17 +278,18 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_room(room, time_now, token_id, joined=True): """ - :param JoinedSyncResult|ArchivedSyncResult room: sync result for a - single room - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - :param joined: True if the user is joined to this room - will mean - we handle ephemeral events + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events - :return: the room, encoded in our response format - :rtype: dict[str, object] + Returns: + dict[str, object]: the room, encoded in our response format """ def serialize(event): # TODO(mjark): Respect formatting requirements in the filter. diff --git a/synapse/state.py b/synapse/state.py index 41d32e664..4a9e148de 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -86,7 +86,8 @@ class StateHandler(object): If `event_type` is specified, then the method returns only the one event (or None) with that `event_type` and `state_key`. - :returns map from (type, state_key) to event + Returns: + map from (type, state_key) to event """ event_ids = yield self.store.get_latest_event_ids_in_room(room_id) @@ -176,10 +177,11 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. - :returns a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Returns: + a Deferred tuple of (`state_group`, `state`, `prev_state`). + `state_group` is the name of a state group if one and only one is + involved. `state` is a map from (type, state_key) to event, and + `prev_state` is a list of event ids. """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -251,9 +253,10 @@ class StateHandler(object): def _resolve_events(self, state_sets, event_type=None, state_key=""): """ - :returns a tuple (new_state, prev_states). new_state is a map - from (type, state_key) to event. prev_states is a list of event_ids. - :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. """ with Measure(self.clock, "state._resolve_events"): state = {} diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index dc5830450..3933b6e2c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ - :param event: the event set actions for - :param tuples: list of tuples of (user_id, actions) + Args: + event: the event set actions for + tuples: list of tuples of (user_id, actions) """ values = [] for uid, actions in tuples: diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a9..d46a963bb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f26..f84fd0e30 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -249,11 +249,14 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) From c906f3066152ba7a65c0765a5812b71bb8a4016c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 16:17:32 +0100 Subject: [PATCH 19/34] Do checks for memberships before creating events --- synapse/handlers/room_member.py | 237 ++++++++++++++++++----------- synapse/state.py | 8 +- tests/rest/client/v1/test_rooms.py | 4 +- 3 files changed, 151 insertions(+), 98 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adc..09b8f6217 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -95,6 +95,80 @@ class RoomMemberHandler(BaseHandler): if remotedomains is not None: remotedomains.add(member.domain) + @defer.inlineCallbacks + def _local_membership_update( + self, requester, target, room_id, membership, + txn_id=None, + ratelimit=True, + ): + msg_handler = self.hs.get_handlers().message_handler + + content = {"membership": membership} + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": membership, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, + ) + + prev_member_event = context.current_state.get( + (EventTypes.Member, target.to_string()), + None + ) + + if event.membership == Membership.JOIN: + if not prev_member_event or prev_member_event.membership != Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + yield user_joined_room(self.distributor, target, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event and prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def remote_join(self, remote_room_hosts, room_id, user, content): + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.hs.get_handlers().federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield user_joined_room(self.distributor, user, room_id) + + def reject_remote_invite(self, user_id, room_id, remote_room_hosts): + return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id + ) + @defer.inlineCallbacks def update_membership( self, @@ -120,28 +194,15 @@ class RoomMemberHandler(BaseHandler): third_party_signed, ) - msg_handler = self.hs.get_handlers().message_handler + if not remote_room_hosts: + remote_room_hosts = [] - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" - - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + current_state = yield self.state_handler.get_current_state( + room_id, latest_event_ids=latest_event_ids, ) - old_state = context.current_state.get((EventTypes.Member, event.state_key)) + old_state = current_state.get((EventTypes.Member, target.to_string())) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -156,13 +217,57 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, + is_host_in_room = self.is_host_in_room(current_state) + + if effective_membership_state == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(current_state): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + if not is_host_in_room: + inviter = yield self.get_inviter(target.to_string(), room_id) + if inviter and not self.hs.is_mine(inviter): + remote_room_hosts.append(inviter.domain) + + content = {"membership": Membership.JOIN} + if requester.is_guest: + content["kind"] = "guest" + + ret = yield self.remote_join( + remote_room_hosts, room_id, target, content + ) + defer.returnValue(ret) + + elif effective_membership_state == Membership.LEAVE: + if not is_host_in_room: + # perhaps we've been invited + inviter = yield self.get_inviter(target.to_string(), room_id) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + + yield self._local_membership_update( + requester=requester, + target=target, + room_id=room_id, + membership=effective_membership_state, + txn_id=txn_id, ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, ) @defer.inlineCallbacks @@ -211,73 +316,19 @@ class RoomMemberHandler(BaseHandler): if prev_event is not None: return - action = "send" - if event.membership == Membership.JOIN: if requester.is_guest and not self._can_guest_join(context.current_state): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter( - target_user.to_string(), context.current_state - ) - if not inviter: - raise SynapseError(404, "Not a known room") - - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) + yield self.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) prev_member_event = context.current_state.get( (EventTypes.Member, target_user.to_string()), @@ -306,11 +357,11 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - def _should_do_dance(self, context, inviter, room_hosts=None): + def _should_do_dance(self, current_state, inviter, room_hosts=None): # TODO: Shouldn't this be remote_room_host? room_hosts = room_hosts or [] - is_host_in_room = self.is_host_in_room(context.current_state) + is_host_in_room = self.is_host_in_room(current_state) if is_host_in_room: return False, room_hosts @@ -344,11 +395,11 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((RoomID.from_string(room_id), servers)) - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invite = yield self.store.get_room_member(user_id=user_id, room_id=room_id) + if invite: + defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): diff --git a/synapse/state.py b/synapse/state.py index 4672ada1b..5a5fd8ff1 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -75,7 +75,8 @@ class StateHandler(object): self._state_cache.start() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): + def get_current_state(self, room_id, event_type=None, state_key="", + latest_event_ids=None): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -88,9 +89,10 @@ class StateHandler(object): :returns map from (type, state_key) to event """ - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - res = yield self.resolve_state_groups(room_id, event_ids) + res = yield self.resolve_state_groups(room_id, latest_event_ids) state = res[1] if event_type: diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4ab8b35e6..8853cbb5f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase): # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=404) + yield self.join(room=room, user=usr, expect_code=403) + yield self.leave(room=room, user=usr, expect_code=403) @defer.inlineCallbacks def test_membership_private_room_perms(self): From aa82cb38e97faa4ad1c15109cbc5b02647ea2461 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 16:36:54 +0100 Subject: [PATCH 20/34] Remove state hack from _create_new_client_event --- synapse/handlers/_base.py | 43 --------------------------------------- 1 file changed, 43 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d407eaeee..2d8ff79a9 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -221,49 +221,6 @@ class BaseHandler(object): context = yield state_handler.compute_event_context(builder) - # If we've received an invite over federation, there are no latest - # events in the room, because we don't know enough about the graph - # fragment we received to treat it like a graph, so the above returned - # no relevant events. It may have returned some events (if we have - # joined and left the room), but not useful ones, like the invite. - if ( - not self.is_host_in_room(context.current_state) and - builder.type == EventTypes.Member - ): - prev_member_event = yield self.store.get_room_member( - builder.sender, builder.room_id - ) - - # The prev_member_event may already be in context.current_state, - # despite us not being present in the room; in particular, if - # inviting user, and all other local users, have already left. - # - # In that case, we have all the information we need, and we don't - # want to drop "context" - not least because we may need to handle - # the invite locally, which will require us to have the whole - # context (not just prev_member_event) to auth it. - # - context_event_ids = ( - e.event_id for e in context.current_state.values() - ) - - if ( - prev_member_event and - prev_member_event.event_id not in context_event_ids - ): - # The prev_member_event is missing from context, so it must - # have arrived over federation and is an outlier. We forcibly - # set our context to the invite we received over federation - builder.prev_events = ( - prev_member_event.event_id, - prev_member_event.prev_events - ) - - context = yield state_handler.compute_event_context( - builder, - old_state=(prev_member_event,) - ) - if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( context.prev_state_events From d76d89323c4b962b4f3ff72a3c9b40f2d2d347b3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 17:39:32 +0100 Subject: [PATCH 21/34] Use computed prev event ids --- synapse/handlers/_base.py | 29 +++++++++++++++++------------ synapse/handlers/message.py | 6 +++++- synapse/handlers/room_member.py | 3 +++ synapse/storage/event_federation.py | 16 ++++++++++++++++ 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 2d8ff79a9..242afa156 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -199,20 +199,25 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 + def _create_new_client_event(self, builder, prev_event_ids=None): + if prev_event_ids: + prev_events = yield self.store.add_event_hashes(prev_event_ids) + prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) + depth = prev_max_depth + 1 else: - depth = 1 + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( + builder.room_id, + ) - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + if latest_ret: + depth = max([d for _, _, d in latest_ret]) + 1 + else: + depth = 1 + + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0bb111d04..10608c0dd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -176,7 +176,7 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None): + def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -187,6 +187,9 @@ class MessageHandler(BaseHandler): Args: event_dict (dict): An entire event + token_id (str) + txn_id (str) + prev_event_ids (list): The prev event ids to use when creating the event Returns: Tuple of created event (FrozenEvent), Context @@ -225,6 +228,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + prev_event_ids=prev_event_ids, ) defer.returnValue((event, context)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 09b8f6217..7c1bb8cfe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -98,6 +98,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, + prev_event_ids, txn_id=None, ratelimit=True, ): @@ -120,6 +121,7 @@ class RoomMemberHandler(BaseHandler): }, token_id=requester.access_token_id, txn_id=txn_id, + prev_event_ids=prev_event_ids, ) yield self.handle_new_client_event( @@ -268,6 +270,7 @@ class RoomMemberHandler(BaseHandler): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, + prev_event_ids=latest_event_ids, ) @defer.inlineCallbacks diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 3489315e0..082794620 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore): room_id, ) + @defer.inlineCallbacks + def get_max_depth_of_events(self, event_ids): + sql = ( + "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" + ) % (",".join(["?"] * len(event_ids)),) + + rows = yield self._execute( + "get_max_depth_of_events", None, + sql, *event_ids + ) + + if rows: + defer.returnValue(rows[0][0]) + else: + defer.returnValue(1) + def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, From 3d76b7cb2ba05fbf17be0a6647f39c419f428c16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 15:52:01 +0100 Subject: [PATCH 22/34] Store invites in a separate table. --- synapse/handlers/room_member.py | 2 +- synapse/storage/events.py | 13 +-- synapse/storage/prepare_database.py | 2 +- synapse/storage/roommember.py | 111 ++++++++++++++++---- synapse/storage/schema/delta/31/invites.sql | 28 +++++ 5 files changed, 124 insertions(+), 32 deletions(-) create mode 100644 synapse/storage/schema/delta/31/invites.sql diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7c1bb8cfe..98e346d48 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_room_member(user_id=user_id, room_id=room_id) + invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c4dc3b3d5..5d299a113 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -367,7 +367,8 @@ class EventsStore(SQLBaseStore): event for event, _ in events_and_contexts if event.type == EventTypes.Member - ] + ], + backfilled=backfilled, ) def event_dict(event): @@ -485,14 +486,8 @@ class EventsStore(SQLBaseStore): return for event, _ in state_events_and_contexts: - if (not event.internal_metadata.is_invite_from_remote() - and event.internal_metadata.is_outlier()): - # Outlier events generally shouldn't clobber the current state. - # However invites from remote severs for rooms we aren't in - # are a bit special: they don't come with any associated - # state so are technically an outlier, however all the - # client-facing code assumes that they are in the current - # state table so we insert the event anyway. + if event.internal_metadata.is_outlier(): + # Outlier events shouldn't clobber the current state. continue if context.rejected: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3f29aad1e..4099387ba 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 30 +SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 430b49c12..4c026b33a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -36,7 +36,7 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_members_txn(self, txn, events): + def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ self._simple_insert_many_txn( @@ -62,6 +62,41 @@ class RoomMemberStore(SQLBaseStore): self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + is_mine = self.hs.is_mine_id(event.state_key) + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -127,6 +162,14 @@ class RoomMemberStore(SQLBaseStore): user_id, [Membership.INVITE] ) + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + defer.returnValue(invite) + defer.returnValue(None) + def get_leave_and_ban_events_for_user(self, user_id): """ Get all the leave events for a user Args: @@ -163,29 +206,55 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["membership = ?" for _ in membership_list]), - ) - args = [user_id] - args.extend(membership_list) + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] - sql = ( - "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM current_state_events as c" - " INNER JOIN room_memberships as m" - " ON m.event_id = c.event_id" - " INNER JOIN events as e" - " ON e.event_id = c.event_id" - " AND m.room_id = c.room_id" - " AND m.user_id = c.state_key" - " WHERE %s" - ) % (where_clause,) + results = [] + if membership_list: + where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( + " OR ".join(["membership = ?" for _ in membership_list]), + ) - txn.execute(sql, args) - return [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + args = [user_id] + args.extend(membership_list) + + sql = ( + "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" + " AND m.user_id = c.state_key" + " WHERE %s" + ) % (where_clause,) + + txn.execute(sql, args) + results = [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] + + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, (user_id,)) + results.extend(RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) for r in self.cursor_to_dict(txn)) + + return results @cached(max_entries=5000) def get_joined_hosts_for_room(self, room_id): diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql new file mode 100644 index 000000000..4f6fb9ea6 --- /dev/null +++ b/synapse/storage/schema/delta/31/invites.sql @@ -0,0 +1,28 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +CREATE INDEX invites_id ON invites(stream_id); +CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); From 92ab45a330c2d6c4e896786135e93b6cabfad1ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 17:07:43 +0100 Subject: [PATCH 23/34] Add upgrade path, rename table --- synapse/storage/roommember.py | 6 +++--- synapse/storage/schema/delta/31/invites.sql | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4c026b33a..abe594274 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -75,7 +75,7 @@ class RoomMemberStore(SQLBaseStore): if event.membership == Membership.INVITE: self._simple_insert_txn( txn, - table="invites", + table="local_invites", values={ "event_id": event.event_id, "invitee": event.state_key, @@ -86,7 +86,7 @@ class RoomMemberStore(SQLBaseStore): ) else: sql = ( - "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" " room_id = ? AND invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" ) @@ -239,7 +239,7 @@ class RoomMemberStore(SQLBaseStore): if do_invite: sql = ( "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM invites as i" + " FROM local_invites as i" " INNER JOIN events as e USING (event_id)" " WHERE invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 4f6fb9ea6..1c83430da 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -14,7 +14,7 @@ */ -CREATE TABLE invites( +CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, @@ -24,5 +24,19 @@ CREATE TABLE invites( replaced_by TEXT ); -CREATE INDEX invites_id ON invites(stream_id); -CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by +FROM events +NATURAL JOIN current_state_events +NATURAL JOIN room_memberships +WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); From 0c53d750e7145f57ed97c544efdb846cc9e37b67 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 18:02:48 +0100 Subject: [PATCH 24/34] Docs and indents --- synapse/handlers/room_member.py | 5 ++++- synapse/storage/roommember.py | 18 +++++++++++++++-- synapse/storage/schema/delta/31/invites.sql | 22 ++++++++++----------- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 98e346d48..f1c3e90ec 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,10 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) + invite = yield self.store.get_invite_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index abe594274..36456a75f 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,11 +66,15 @@ class RoomMemberStore(SQLBaseStore): self.get_invited_rooms_for_user.invalidate, (event.state_key,) ) - is_mine = self.hs.is_mine_id(event.state_key) + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. is_new_state = not backfilled and ( not event.internal_metadata.is_outlier() or event.internal_metadata.is_invite_from_remote() ) + is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: self._simple_insert_txn( @@ -163,7 +167,17 @@ class RoomMemberStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ invites = yield self.get_invited_rooms_for_user(user_id) for invite in invites: if invite.room_id == room_id: diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 1c83430da..2c57846d5 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -26,17 +26,17 @@ CREATE TABLE local_invites( -- Insert all invites for local users into new `invites` table INSERT INTO local_invites SELECT - stream_ordering as stream_id, - sender as inviter, - state_key as invitee, - event_id, - room_id, - NULL as locally_rejected, - NULL as replaced_by -FROM events -NATURAL JOIN current_state_events -NATURAL JOIN room_memberships -WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); CREATE INDEX local_invites_id ON local_invites(stream_id); CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); From df727f212606b771b1410c8e322fb8a99d159de4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Apr 2016 11:13:24 +0100 Subject: [PATCH 25/34] Fix stuck invites If rejecting a remote invite fails with an error response don't fail the entire request; instead mark the invite as locally rejected. This fixes the bug where users can get stuck invites which they can neither accept nor reject. --- synapse/handlers/federation.py | 34 ++++++++++++++++++++++----------- synapse/handlers/room_member.py | 18 +++++++++++++---- synapse/storage/__init__.py | 3 ++- synapse/storage/roommember.py | 19 ++++++++++++++++++ 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4049c01d2..19769eecd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -784,13 +784,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) + try: + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + signed_event = self._sign_event(event) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") # Try the host we successfully got a response to /make_join/ # request first. @@ -800,10 +806,16 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) + try: + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") context = yield self.state_handler.compute_event_context(event) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f1c3e90ec..6c7409215 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -258,10 +258,20 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - ret = yield self.reject_remote_invite( - target.to_string(), room_id, remote_room_hosts - ) - defer.returnValue(ret) + + try: + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + except SynapseError as e: + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + + defer.returnValue({}) yield self._local_membership_update( requester=requester, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 57863bba4..07916b292 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,7 +94,8 @@ class DataStore(RoomMemberStore, RoomStore, ) self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering" + db_conn, "events", "stream_ordering", + extra_tables=[("local_invites", "stream_id")] ) self._backfill_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", step=-1 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 36456a75f..66e7a40e3 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -102,6 +102,25 @@ class RoomMemberStore(SQLBaseStore): event.state_key, )) + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. From 1d4deff25a1edce73fb3d2f1b327d672a75581b0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 5 Apr 2016 11:23:57 +0100 Subject: [PATCH 26/34] Separate generating the replication response... from doing the http request parsing to make it easier to write unit tests for replication. --- synapse/replication/resource.py | 95 ++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index c51a6fa10..a543af68f 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -145,32 +145,43 @@ class ReplicationResource(Resource): timeout = parse_integer(request, "timeout", 10 * 1000) request.setHeader(b"Content-Type", b"application/json") - writer = _Writer(request) - @defer.inlineCallbacks + request_streams = { + name: parse_integer(request, name) + for names in STREAM_NAMES for name in names + } + request_streams["streams"] = parse_string(request, "streams") + def replicate(): - current_token = yield self.current_replication_token() - logger.info("Replicating up to %r", current_token) + return self.replicate(request_streams, limit) - yield self.account_data(writer, current_token, limit) - yield self.events(writer, current_token, limit) - yield self.presence(writer, current_token) # TODO: implement limit - yield self.typing(writer, current_token) # TODO: implement limit - yield self.receipts(writer, current_token, limit) - yield self.push_rules(writer, current_token, limit) - yield self.pushers(writer, current_token, limit) - yield self.state(writer, current_token, limit) - self.streams(writer, current_token) + result = yield self.notifier.wait_for_replication(replicate, timeout) - logger.info("Replicated %d rows", writer.total) - defer.returnValue(writer.total) + request.write(json.dumps(result, ensure_ascii=False)) + finish_request(request) - yield self.notifier.wait_for_replication(replicate, timeout) + @defer.inlineCallbacks + def replicate(self, request_streams, limit): + writer = _Writer() + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) - writer.finish() + yield self.account_data(writer, current_token, limit, request_streams) + yield self.events(writer, current_token, limit, request_streams) + # TODO: implement limit + yield self.presence(writer, current_token, request_streams) + yield self.typing(writer, current_token, request_streams) + yield self.receipts(writer, current_token, limit, request_streams) + yield self.push_rules(writer, current_token, limit, request_streams) + yield self.pushers(writer, current_token, limit, request_streams) + yield self.state(writer, current_token, limit, request_streams) + self.streams(writer, current_token, request_streams) - def streams(self, writer, current_token): - request_token = parse_string(writer.request, "streams") + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.finish()) + + def streams(self, writer, current_token, request_streams): + request_token = request_streams.get("streams") streams = [] @@ -195,9 +206,9 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def events(self, writer, current_token, limit): - request_events = parse_integer(writer.request, "events") - request_backfill = parse_integer(writer.request, "backfill") + def events(self, writer, current_token, limit, request_streams): + request_events = request_streams.get("events") + request_backfill = request_streams.get("backfill") if request_events is not None or request_backfill is not None: if request_events is None: @@ -228,10 +239,10 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def presence(self, writer, current_token): + def presence(self, writer, current_token, request_streams): current_position = current_token.presence - request_presence = parse_integer(writer.request, "presence") + request_presence = request_streams.get("presence") if request_presence is not None: presence_rows = yield self.presence_handler.get_all_presence_updates( @@ -244,10 +255,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def typing(self, writer, current_token): + def typing(self, writer, current_token, request_streams): current_position = current_token.presence - request_typing = parse_integer(writer.request, "typing") + request_typing = request_streams.get("typing") if request_typing is not None: typing_rows = yield self.typing_handler.get_all_typing_updates( @@ -258,10 +269,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def receipts(self, writer, current_token, limit): + def receipts(self, writer, current_token, limit, request_streams): current_position = current_token.receipts - request_receipts = parse_integer(writer.request, "receipts") + request_receipts = request_streams.get("receipts") if request_receipts is not None: receipts_rows = yield self.store.get_all_updated_receipts( @@ -272,12 +283,12 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def account_data(self, writer, current_token, limit): + def account_data(self, writer, current_token, limit, request_streams): current_position = current_token.account_data - user_account_data = parse_integer(writer.request, "user_account_data") - room_account_data = parse_integer(writer.request, "room_account_data") - tag_account_data = parse_integer(writer.request, "tag_account_data") + user_account_data = request_streams.get("user_account_data") + room_account_data = request_streams.get("room_account_data") + tag_account_data = request_streams.get("tag_account_data") if user_account_data is not None or room_account_data is not None: if user_account_data is None: @@ -303,10 +314,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit): + def push_rules(self, writer, current_token, limit, request_streams): current_position = current_token.push_rules - push_rules = parse_integer(writer.request, "push_rules") + push_rules = request_streams.get("push_rules") if push_rules is not None: rows = yield self.store.get_all_push_rule_updates( @@ -318,10 +329,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def pushers(self, writer, current_token, limit): + def pushers(self, writer, current_token, limit, request_streams): current_position = current_token.pushers - pushers = parse_integer(writer.request, "pushers") + pushers = request_streams.get("pushers") + if pushers is not None: updated, deleted = yield self.store.get_all_updated_pushers( pushers, current_position, limit @@ -336,10 +348,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def state(self, writer, current_token, limit): + def state(self, writer, current_token, limit, request_streams): current_position = current_token.state - state = parse_integer(writer.request, "state") + state = request_streams.get("state") + if state is not None: state_groups, state_group_state = ( yield self.store.get_all_new_state_groups( @@ -356,9 +369,8 @@ class ReplicationResource(Resource): class _Writer(object): """Writes the streams as a JSON object as the response to the request""" - def __init__(self, request): + def __init__(self): self.streams = {} - self.request = request self.total = 0 def write_header_and_rows(self, name, rows, fields, position=None): @@ -377,8 +389,7 @@ class _Writer(object): self.total += len(rows) def finish(self): - self.request.write(json.dumps(self.streams, ensure_ascii=False)) - finish_request(self.request) + return self.streams class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( From 6222ae51cee230fc746d0706db13d8928f28234b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Apr 2016 12:56:29 +0100 Subject: [PATCH 27/34] Don't backfill from self --- synapse/handlers/federation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 267fedf11..edffa560b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -289,6 +289,9 @@ class FederationHandler(BaseHandler): def backfill(self, dest, room_id, limit, extremities=[]): """ Trigger a backfill request to `dest` for the given `room_id` """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + if not extremities: extremities = yield self.store.get_oldest_events_in_room(room_id) @@ -455,7 +458,7 @@ class FederationHandler(BaseHandler): likely_domains = [ domain for domain, depth in curr_domains - if domain is not self.server_name + if domain != self.server_name ] @defer.inlineCallbacks From a1e0d316ea354fce07939073d9afc9c5d1013939 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:05:19 +0100 Subject: [PATCH 28/34] Move _get_cache_dict into the SQLBaseStore --- synapse/storage/__init__.py | 33 --------------------------------- synapse/storage/_base.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 07916b292..045ae6c03 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -177,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore, self.__presence_on_startup = None return active_on_startup - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() - - cache = { - row[0]: int(row[1]) - for row in rows - } - - if cache: - min_val = min(cache.values()) - else: - min_val = max_value - - return cache, min_val - def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b75b79df3..04d7fcf6d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -816,6 +816,40 @@ class SQLBaseStore(object): self._next_stream_id += 1 return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying From 87f2dec8d475f038beb138bc56e3ef76fcb83ec6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:08:05 +0100 Subject: [PATCH 29/34] Make the cache objects be per instance rather than being global --- synapse/storage/receipts.py | 4 +-- synapse/storage/registration.py | 2 +- synapse/storage/state.py | 4 +-- synapse/util/caches/descriptors.py | 45 ++++++++++++++++-------------- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e..7fdd84bbd 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb..1f71773aa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1, inlineCallbacks=True) def are_guests(self, user_ids): sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f940601..c5d2a3a6d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19f..758f5982b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, From 8aab9d87fa6739345810f0edf3982fe7f898ee30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Apr 2016 14:08:18 +0100 Subject: [PATCH 30/34] Don't require config to create database --- scripts/synapse_port_db | 13 ++-- synapse/app/homeserver.py | 15 +++-- synapse/storage/engines/__init__.py | 6 +- synapse/storage/engines/postgres.py | 8 +-- synapse/storage/engines/sqlite3.py | 13 +--- synapse/storage/prepare_database.py | 64 ++++++------------- .../schema/delta/14/upgrade_appservice_db.py | 6 +- synapse/storage/schema/delta/20/pushers.py | 6 +- synapse/storage/schema/delta/25/fts.py | 6 +- synapse/storage/schema/delta/27/ts.py | 6 +- synapse/storage/schema/delta/30/as_users.py | 4 +- tests/storage/test_base.py | 2 +- tests/utils.py | 6 +- 13 files changed, 69 insertions(+), 86 deletions(-) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a2a0f364c..253a6ef6c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -19,6 +19,7 @@ from twisted.enterprise import adbapi from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database import argparse import curses @@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = { "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], + "presence_stream": ["currently_active"], } @@ -292,7 +294,7 @@ class Porter(object): } ) - database_engine.prepare_database(db_conn) + prepare_database(db_conn, database_engine, config=None) db_conn.commit() @@ -309,8 +311,8 @@ class Porter(object): **self.postgres_config["args"] ) - sqlite_engine = create_engine(FakeConfig(sqlite_config)) - postgres_engine = create_engine(FakeConfig(postgres_config)) + sqlite_engine = create_engine(sqlite_config) + postgres_engine = create_engine(postgres_config) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) @@ -792,8 +794,3 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) - - -class FakeConfig: - def __init__(self, database_config): - self.database_config = database_config diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fcdc8e6e1..2b4473b9a 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -33,7 +33,7 @@ from synapse.python_dependencies import ( from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.server import HomeServer @@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self): + def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { @@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer): } db_conn = self.database_engine.module.connect(**db_params) - self.database_engine.on_new_connection(db_conn) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) return db_conn @@ -386,7 +387,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config) + database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( @@ -402,8 +403,10 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn() - database_engine.prepare_database(db_conn) + db_conn = hs.get_db_conn(run_new_connection=False) + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) + hs.run_startup_checks(db_conn, database_engine) db_conn.commit() diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a48230b93..7bb5de1fe 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,13 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(config): - name = config.database_config["name"] +def create_engine(database_config): + name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module, config=config) + return engine_class(module) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a09685b4d..c2290943b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) - self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,9 +41,6 @@ class PostgresEngine(object): self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) - def prepare_database(self, db_conn): - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): return error.pgcode in ["40001", "40P01"] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 522b90594..14203aa50 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import ( - prepare_database, prepare_sqlite3_database -) +from synapse.storage.prepare_database import prepare_database import struct @@ -23,9 +21,8 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module - self.config = config def check_database(self, txn): pass @@ -34,13 +31,9 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - self.prepare_database(db_conn) + prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) - def prepare_database(self, db_conn): - prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4099387ba..00833422a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. + + If `config` is None then prepare_database will assert that no upgrade is + necessary, *or* will create a fresh database if the database is empty. """ try: cur = db_conn.cursor() @@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config): if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config - ) - else: - _setup_new_database(cur, database_engine, config) - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + if config is None: + if user_version != SCHEMA_VERSION: + # If we don't pass in a config file then we are expecting to + # have already upgraded the DB. + raise UpgradeDatabaseException("Database needs to be upgraded") + else: + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine, config + ) + else: + _setup_new_database(cur, database_engine) cur.close() db_conn.commit() @@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine, config): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config): applied_delta_files=[], upgraded=False, database_engine=database_engine, - config=config, + config=None, + is_empty=True, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config): + upgraded, database_engine, config, is_empty=False): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine, config=config) + module.run_create(cur, database_engine) + if not is_empty: + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 5c40a7775..8755bb2e4 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, *args, **kwargs): +def run_create(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: @@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs): "UPDATE application_services_regex SET regex=? WHERE id=?", (new_regex, row[0]) ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 29164732a..147496a38 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index d3ff2b177..4269ac69a 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -43,7 +43,7 @@ SQLITE_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index f8c16391a..71b12a273 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -27,7 +27,7 @@ ALTER_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4f6e9dd54..b417e3ac0 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): # Maybe we already added the column? Hope so... pass + +def run_upgrade(cur, database_engine, config, *args, **kwargs): cur.execute("SELECT name FROM users") rows = cur.fetchall() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 2e33beb07..afbefb2e2 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), ) self.datastore = SQLBaseStore(hs) diff --git a/tests/utils.py b/tests/utils.py index 52405502e..c179df31e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), **kargs ) @@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): return conn def create_engine(self): - return create_engine(self.config) + return create_engine(self.config.database_config) class MemoryDataStore(object): From 75fb9ac1be0fada60cdde38153ac0e3fe2b19a0a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 14:12:51 +0100 Subject: [PATCH 31/34] Add a slaved events store class Add a test to check that get_room_names_and_aliases does the same thing on both the master and on the slave data store. --- synapse/replication/slave/__init__.py | 14 ++ synapse/replication/slave/storage/__init__.py | 14 ++ synapse/replication/slave/storage/_base.py | 28 +++ .../slave/storage/_slaved_id_tracker.py | 30 +++ synapse/replication/slave/storage/events.py | 198 ++++++++++++++++++ synapse/storage/events.py | 4 +- tests/replication/slave/__init__.py | 14 ++ tests/replication/slave/storage/__init__.py | 14 ++ tests/replication/slave/storage/_base.py | 57 +++++ .../replication/slave/storage/test_events.py | 114 ++++++++++ 10 files changed, 485 insertions(+), 2 deletions(-) create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_base.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py create mode 100644 synapse/replication/slave/storage/events.py create mode 100644 tests/replication/slave/__init__.py create mode 100644 tests/replication/slave/storage/__init__.py create mode 100644 tests/replication/slave/storage/_base.py create mode 100644 tests/replication/slave/storage/test_events.py diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 000000000..46e43ce1c --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from twisted.internet import defer + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + + def stream_positions(self): + return {} + + def process_replication(self, result): + return defer.succeed(None) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 000000000..24b5c79d4 --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 000000000..68b924e37 --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.room import RoomStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.state import StateStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + _set_before_and_after = DataStore._set_before_and_after + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _parse_events_txn = DataStore._parse_events_txn.__func__ + _get_events_txn = DataStore._get_events_txn.__func__ + _fetch_events_txn = DataStore._fetch_events_txn.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfilled"] = self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self._invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def _invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_room_name_and_aliases.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_users_in_room.invalidate((event.room_id,)) + # self._membership_stream_cache.entity_has_changed( + # event.state_key, event.internal_metadata.stream_ordering + # ) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + self.get_room_name_and_aliases.invalidate( + (event.room_id,) + ) + pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d299a113..ee87a7171 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1134,7 +1134,7 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT -event_stream_ordering FROM current_state_resets" + "SELECT event_stream_ordering FROM current_state_resets" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering ASC" @@ -1143,7 +1143,7 @@ class EventsStore(SQLBaseStore): state_resets = txn.fetchall() sql = ( - "SELECT -event_stream_ordering, event_id, state_group" + "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/tests/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/tests/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py new file mode 100644 index 000000000..0f525a894 --- /dev/null +++ b/tests/replication/slave/storage/_base.py @@ -0,0 +1,57 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from tests import unittest + +from synapse.replication.slave.storage.events import SlavedEventStore + +from mock import Mock, NonCallableMock +from tests.utils import setup_test_homeserver +from synapse.replication.resource import ReplicationResource + + +class BaseSlavedStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "blue", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.replication = ReplicationResource(self.hs) + + self.master_store = self.hs.get_datastore() + self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.event_id = 0 + + @defer.inlineCallbacks + def replicate(self): + streams = self.slaved_store.stream_positions() + result = yield self.replication.replicate(streams, 100) + yield self.slaved_store.process_replication(result) + + @defer.inlineCallbacks + def check(self, method, args, expected_result=None): + master_result = yield getattr(self.master_store, method)(*args) + slaved_result = yield getattr(self.slaved_store, method)(*args) + self.assertEqual(master_result, slaved_result) + if expected_result is not None: + self.assertEqual(master_result, expected_result) + self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py new file mode 100644 index 000000000..c30c7c606 --- /dev/null +++ b/tests/replication/slave/storage/test_events.py @@ -0,0 +1,114 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.types import UserID +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +USER = UserID.from_string(USER_ID) +OUTLIER = {"outlier": True} +ROOM_ID = "!room:blue" + + +class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + + @defer.inlineCallbacks + def test_room_name_and_aliases(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.persist(type="m.room.name", key="", name="name1") + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#1:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) + ) + + # Set the room name. + yield self.persist(type="m.room.name", key="", name="name2") + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) + ) + + # Set the room aliases. + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#2:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) + ) + + # Leave and join the room clobbering the state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), (None, []) + ) + + event_id = 0 + + @defer.inlineCallbacks + def persist( + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, + internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content + ): + if depth is None: + depth = self.event_id + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + event = FrozenEvent(event_dict, internal_metadata_dict=internal) + + self.event_id += 1 + + context = EventContext(current_state=state) + + if backfill: + yield self.master_store.persist_events( + [(event, context)], backfilled=True + ) + else: + yield self.master_store.persist_event( + event, context, current_state=reset_state + ) + defer.returnValue(event) From 1e05637e37f62445d84e43ae89e441f1833a32e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 Apr 2016 15:19:45 +0100 Subject: [PATCH 32/34] Let users see their own leave events ... otherwise clients get confused. Fixes https://matrix.org/jira/browse/SYN-662, https://github.com/vector-im/vector-web/issues/368 --- synapse/handlers/_base.py | 51 ++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c77afe7f5..88d8b9ba5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = ( ) +MEMBERSHIP_PRIORITY = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, +) + + class BaseHandler(object): """ Common base class for the event handlers. @@ -72,6 +81,7 @@ class BaseHandler(object): * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events + events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -86,6 +96,12 @@ class BaseHandler(object): ) def allowed(event, user_id, is_peeking): + """ + Args: + event (synapse.events.EventBase): event to check + user_id (str) + is_peeking (bool) + """ state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -117,17 +133,30 @@ class BaseHandler(object): if old_priority < new_priority: visibility = prev_visibility - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None + # likewise, if the event is the user's own membership event, use + # the 'most joined' membership + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: + membership = event.content.get("membership", None) + if membership not in MEMBERSHIP_PRIORITY: + membership = "leave" + + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership not in MEMBERSHIP_PRIORITY: + prev_membership = "leave" + + new_priority = MEMBERSHIP_PRIORITY.index(membership) + old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) + if old_priority < new_priority: + membership = prev_membership + + # otherwise, get the user's membership at the time of the event. + if membership is None: + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id not in event_id_forgotten: + membership = membership_event.membership # if the user was a member of the room at the time of the event, # they can see it. From 6bfec56796132520ad864ad00f8dd6773512f9f4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 16:17:15 +0100 Subject: [PATCH 33/34] Test that room membership is replicated --- synapse/replication/slave/storage/events.py | 7 +- .../replication/slave/storage/test_events.py | 71 ++++++++++++++++--- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 68b924e37..680dc8953 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -71,9 +71,6 @@ class SlavedEventStore(BaseSlavedStore): get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -95,6 +92,10 @@ class SlavedEventStore(BaseSlavedStore): _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index c30c7c606..351d777fb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,14 +14,14 @@ from ._base import BaseSlavedStoreTestCase -from synapse.types import UserID from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext +from synapse.storage.roommember import RoomsForUser from twisted.internet import defer USER_ID = "@feeling:blue" -USER = UserID.from_string(USER_ID) +USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" @@ -69,16 +69,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_room_name_and_aliases", (ROOM_ID,), (None, []) ) + @defer.inlineCallbacks + def test_room_members(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Join the room. + join = yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + + # Leave the room. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Add some other user to the room. + join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) + + # Join the room clobbering the state. + # This should remove any evidence of the other user being in the room. + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + yield self.check("get_rooms_for_user", (USER_ID_2,), []) + event_id = 0 @defer.inlineCallbacks def persist( - self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, - internal={}, - state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], - **content + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content ): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ if depth is None: depth = self.event_id @@ -103,12 +153,17 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): context = EventContext(current_state=state) + ordering = None if backfill: yield self.master_store.persist_events( [(event, context)], backfilled=True ) else: - yield self.master_store.persist_event( + ordering, _ = yield self.master_store.persist_event( event, context, current_state=reset_state ) + + if ordering: + event.internal_metadata.stream_ordering = ordering + defer.returnValue(event) From 1ef036567051218c38de5529472d3c9000c6960d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Apr 2016 09:42:52 +0100 Subject: [PATCH 34/34] Set profile information when joining rooms remotely --- synapse/handlers/room_member.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fe2315df8..8c41cb6f3 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -233,6 +233,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts.append(inviter.domain) content = {"membership": Membership.JOIN} + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + if requester.is_guest: content["kind"] = "guest"