diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a2a0f364c..253a6ef6c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -19,6 +19,7 @@ from twisted.enterprise import adbapi from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database import argparse import curses @@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = { "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], + "presence_stream": ["currently_active"], } @@ -292,7 +294,7 @@ class Porter(object): } ) - database_engine.prepare_database(db_conn) + prepare_database(db_conn, database_engine, config=None) db_conn.commit() @@ -309,8 +311,8 @@ class Porter(object): **self.postgres_config["args"] ) - sqlite_engine = create_engine(FakeConfig(sqlite_config)) - postgres_engine = create_engine(FakeConfig(postgres_config)) + sqlite_engine = create_engine(sqlite_config) + postgres_engine = create_engine(postgres_config) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) @@ -792,8 +794,3 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) - - -class FakeConfig: - def __init__(self, database_config): - self.database_config = database_config diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fcdc8e6e1..2b4473b9a 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -33,7 +33,7 @@ from synapse.python_dependencies import ( from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.server import HomeServer @@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self): + def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { @@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer): } db_conn = self.database_engine.module.connect(**db_params) - self.database_engine.on_new_connection(db_conn) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) return db_conn @@ -386,7 +387,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config) + database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( @@ -402,8 +403,10 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn() - database_engine.prepare_database(db_conn) + db_conn = hs.get_db_conn(run_new_connection=False) + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) + hs.run_startup_checks(db_conn, database_engine) db_conn.commit() diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 9c92ea01e..aaf6b1b83 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -37,6 +37,15 @@ VISIBILITY_PRIORITY = ( ) +MEMBERSHIP_PRIORITY = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, +) + + class BaseHandler(object): """ Common base class for the event handlers. @@ -72,6 +81,7 @@ class BaseHandler(object): * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events + events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -86,6 +96,12 @@ class BaseHandler(object): ) def allowed(event, user_id, is_peeking): + """ + Args: + event (synapse.events.EventBase): event to check + user_id (str) + is_peeking (bool) + """ state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. @@ -117,17 +133,30 @@ class BaseHandler(object): if old_priority < new_priority: visibility = prev_visibility - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None + # likewise, if the event is the user's own membership event, use + # the 'most joined' membership + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: + membership = event.content.get("membership", None) + if membership not in MEMBERSHIP_PRIORITY: + membership = "leave" + + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership not in MEMBERSHIP_PRIORITY: + prev_membership = "leave" + + new_priority = MEMBERSHIP_PRIORITY.index(membership) + old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) + if old_priority < new_priority: + membership = prev_membership + + # otherwise, get the user's membership at the time of the event. + if membership is None: + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + if membership_event.event_id not in event_id_forgotten: + membership = membership_event.membership # if the user was a member of the room at the time of the event, # they can see it. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fc5e0b059..3686738e5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -284,6 +284,9 @@ class FederationHandler(BaseHandler): def backfill(self, dest, room_id, limit, extremities=[]): """ Trigger a backfill request to `dest` for the given `room_id` """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + if not extremities: extremities = yield self.store.get_oldest_events_in_room(room_id) @@ -450,7 +453,7 @@ class FederationHandler(BaseHandler): likely_domains = [ domain for domain, depth in curr_domains - if domain is not self.server_name + if domain != self.server_name ] @defer.inlineCallbacks diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fe2315df8..8c41cb6f3 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -233,6 +233,11 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts.append(inviter.domain) content = {"membership": Membership.JOIN} + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + if requester.is_guest: content["kind"] = "guest" diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index c51a6fa10..a543af68f 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -145,32 +145,43 @@ class ReplicationResource(Resource): timeout = parse_integer(request, "timeout", 10 * 1000) request.setHeader(b"Content-Type", b"application/json") - writer = _Writer(request) - @defer.inlineCallbacks + request_streams = { + name: parse_integer(request, name) + for names in STREAM_NAMES for name in names + } + request_streams["streams"] = parse_string(request, "streams") + def replicate(): - current_token = yield self.current_replication_token() - logger.info("Replicating up to %r", current_token) + return self.replicate(request_streams, limit) - yield self.account_data(writer, current_token, limit) - yield self.events(writer, current_token, limit) - yield self.presence(writer, current_token) # TODO: implement limit - yield self.typing(writer, current_token) # TODO: implement limit - yield self.receipts(writer, current_token, limit) - yield self.push_rules(writer, current_token, limit) - yield self.pushers(writer, current_token, limit) - yield self.state(writer, current_token, limit) - self.streams(writer, current_token) + result = yield self.notifier.wait_for_replication(replicate, timeout) - logger.info("Replicated %d rows", writer.total) - defer.returnValue(writer.total) + request.write(json.dumps(result, ensure_ascii=False)) + finish_request(request) - yield self.notifier.wait_for_replication(replicate, timeout) + @defer.inlineCallbacks + def replicate(self, request_streams, limit): + writer = _Writer() + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) - writer.finish() + yield self.account_data(writer, current_token, limit, request_streams) + yield self.events(writer, current_token, limit, request_streams) + # TODO: implement limit + yield self.presence(writer, current_token, request_streams) + yield self.typing(writer, current_token, request_streams) + yield self.receipts(writer, current_token, limit, request_streams) + yield self.push_rules(writer, current_token, limit, request_streams) + yield self.pushers(writer, current_token, limit, request_streams) + yield self.state(writer, current_token, limit, request_streams) + self.streams(writer, current_token, request_streams) - def streams(self, writer, current_token): - request_token = parse_string(writer.request, "streams") + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.finish()) + + def streams(self, writer, current_token, request_streams): + request_token = request_streams.get("streams") streams = [] @@ -195,9 +206,9 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def events(self, writer, current_token, limit): - request_events = parse_integer(writer.request, "events") - request_backfill = parse_integer(writer.request, "backfill") + def events(self, writer, current_token, limit, request_streams): + request_events = request_streams.get("events") + request_backfill = request_streams.get("backfill") if request_events is not None or request_backfill is not None: if request_events is None: @@ -228,10 +239,10 @@ class ReplicationResource(Resource): ) @defer.inlineCallbacks - def presence(self, writer, current_token): + def presence(self, writer, current_token, request_streams): current_position = current_token.presence - request_presence = parse_integer(writer.request, "presence") + request_presence = request_streams.get("presence") if request_presence is not None: presence_rows = yield self.presence_handler.get_all_presence_updates( @@ -244,10 +255,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def typing(self, writer, current_token): + def typing(self, writer, current_token, request_streams): current_position = current_token.presence - request_typing = parse_integer(writer.request, "typing") + request_typing = request_streams.get("typing") if request_typing is not None: typing_rows = yield self.typing_handler.get_all_typing_updates( @@ -258,10 +269,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def receipts(self, writer, current_token, limit): + def receipts(self, writer, current_token, limit, request_streams): current_position = current_token.receipts - request_receipts = parse_integer(writer.request, "receipts") + request_receipts = request_streams.get("receipts") if request_receipts is not None: receipts_rows = yield self.store.get_all_updated_receipts( @@ -272,12 +283,12 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def account_data(self, writer, current_token, limit): + def account_data(self, writer, current_token, limit, request_streams): current_position = current_token.account_data - user_account_data = parse_integer(writer.request, "user_account_data") - room_account_data = parse_integer(writer.request, "room_account_data") - tag_account_data = parse_integer(writer.request, "tag_account_data") + user_account_data = request_streams.get("user_account_data") + room_account_data = request_streams.get("room_account_data") + tag_account_data = request_streams.get("tag_account_data") if user_account_data is not None or room_account_data is not None: if user_account_data is None: @@ -303,10 +314,10 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit): + def push_rules(self, writer, current_token, limit, request_streams): current_position = current_token.push_rules - push_rules = parse_integer(writer.request, "push_rules") + push_rules = request_streams.get("push_rules") if push_rules is not None: rows = yield self.store.get_all_push_rule_updates( @@ -318,10 +329,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def pushers(self, writer, current_token, limit): + def pushers(self, writer, current_token, limit, request_streams): current_position = current_token.pushers - pushers = parse_integer(writer.request, "pushers") + pushers = request_streams.get("pushers") + if pushers is not None: updated, deleted = yield self.store.get_all_updated_pushers( pushers, current_position, limit @@ -336,10 +348,11 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def state(self, writer, current_token, limit): + def state(self, writer, current_token, limit, request_streams): current_position = current_token.state - state = parse_integer(writer.request, "state") + state = request_streams.get("state") + if state is not None: state_groups, state_group_state = ( yield self.store.get_all_new_state_groups( @@ -356,9 +369,8 @@ class ReplicationResource(Resource): class _Writer(object): """Writes the streams as a JSON object as the response to the request""" - def __init__(self, request): + def __init__(self): self.streams = {} - self.request = request self.total = 0 def write_header_and_rows(self, name, rows, fields, position=None): @@ -377,8 +389,7 @@ class _Writer(object): self.total += len(rows) def finish(self): - self.request.write(json.dumps(self.streams, ensure_ascii=False)) - finish_request(self.request) + return self.streams class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 000000000..46e43ce1c --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from twisted.internet import defer + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + + def stream_positions(self): + return {} + + def process_replication(self, result): + return defer.succeed(None) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 000000000..24b5c79d4 --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 000000000..680dc8953 --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.room import RoomStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.state import StateStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + _set_before_and_after = DataStore._set_before_and_after + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _parse_events_txn = DataStore._parse_events_txn.__func__ + _get_events_txn = DataStore._get_events_txn.__func__ + _fetch_events_txn = DataStore._fetch_events_txn.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfilled"] = self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self._invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def _invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_room_name_and_aliases.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_users_in_room.invalidate((event.room_id,)) + # self._membership_stream_cache.entity_has_changed( + # event.state_key, event.internal_metadata.stream_ordering + # ) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + self.get_room_name_and_aliases.invalidate( + (event.room_id,) + ) + pass diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 07916b292..045ae6c03 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -177,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore, self.__presence_on_startup = None return active_on_startup - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() - - cache = { - row[0]: int(row[1]) - for row in rows - } - - if cache: - min_val = min(cache.values()) - else: - min_val = max_value - - return cache, min_val - def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b75b79df3..04d7fcf6d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -816,6 +816,40 @@ class SQLBaseStore(object): self._next_stream_id += 1 return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a48230b93..7bb5de1fe 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,13 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(config): - name = config.database_config["name"] +def create_engine(database_config): + name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module, config=config) + return engine_class(module) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a09685b4d..c2290943b 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) - self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,9 +41,6 @@ class PostgresEngine(object): self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) - def prepare_database(self, db_conn): - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): return error.pgcode in ["40001", "40P01"] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 522b90594..14203aa50 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import ( - prepare_database, prepare_sqlite3_database -) +from synapse.storage.prepare_database import prepare_database import struct @@ -23,9 +21,8 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module - self.config = config def check_database(self, txn): pass @@ -34,13 +31,9 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - self.prepare_database(db_conn) + prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) - def prepare_database(self, db_conn): - prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): return False diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5be5bc01b..308a2c9b0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1148,7 +1148,7 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT -event_stream_ordering FROM current_state_resets" + "SELECT event_stream_ordering FROM current_state_resets" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering ASC" @@ -1157,7 +1157,7 @@ class EventsStore(SQLBaseStore): state_resets = txn.fetchall() sql = ( - "SELECT -event_stream_ordering, event_id, state_group" + "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4099387ba..00833422a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. + + If `config` is None then prepare_database will assert that no upgrade is + necessary, *or* will create a fresh database if the database is empty. """ try: cur = db_conn.cursor() @@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config): if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config - ) - else: - _setup_new_database(cur, database_engine, config) - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + if config is None: + if user_version != SCHEMA_VERSION: + # If we don't pass in a config file then we are expecting to + # have already upgraded the DB. + raise UpgradeDatabaseException("Database needs to be upgraded") + else: + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine, config + ) + else: + _setup_new_database(cur, database_engine) cur.close() db_conn.commit() @@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine, config): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config): applied_delta_files=[], upgraded=False, database_engine=database_engine, - config=config, + config=None, + is_empty=True, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config): + upgraded, database_engine, config, is_empty=False): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine, config=config) + module.run_create(cur, database_engine) + if not is_empty: + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 59d1ac031..3b8805593 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 5c40a7775..8755bb2e4 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, *args, **kwargs): +def run_create(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: @@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs): "UPDATE application_services_regex SET regex=? WHERE id=?", (new_regex, row[0]) ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 29164732a..147496a38 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index d3ff2b177..4269ac69a 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -43,7 +43,7 @@ SQLITE_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index f8c16391a..71b12a273 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -27,7 +27,7 @@ ALTER_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4f6e9dd54..b417e3ac0 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): # Maybe we already added the column? Hope so... pass + +def run_upgrade(cur, database_engine, config, *args, **kwargs): cur.execute("SELECT name FROM users") rows = cur.fetchall() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f940601..c5d2a3a6d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19f..758f5982b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/tests/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py new file mode 100644 index 000000000..b7df13c9e --- /dev/null +++ b/tests/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py new file mode 100644 index 000000000..0f525a894 --- /dev/null +++ b/tests/replication/slave/storage/_base.py @@ -0,0 +1,57 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from tests import unittest + +from synapse.replication.slave.storage.events import SlavedEventStore + +from mock import Mock, NonCallableMock +from tests.utils import setup_test_homeserver +from synapse.replication.resource import ReplicationResource + + +class BaseSlavedStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "blue", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.replication = ReplicationResource(self.hs) + + self.master_store = self.hs.get_datastore() + self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.event_id = 0 + + @defer.inlineCallbacks + def replicate(self): + streams = self.slaved_store.stream_positions() + result = yield self.replication.replicate(streams, 100) + yield self.slaved_store.process_replication(result) + + @defer.inlineCallbacks + def check(self, method, args, expected_result=None): + master_result = yield getattr(self.master_store, method)(*args) + slaved_result = yield getattr(self.slaved_store, method)(*args) + self.assertEqual(master_result, slaved_result) + if expected_result is not None: + self.assertEqual(master_result, expected_result) + self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py new file mode 100644 index 000000000..351d777fb --- /dev/null +++ b/tests/replication/slave/storage/test_events.py @@ -0,0 +1,169 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.storage.roommember import RoomsForUser + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +USER_ID_2 = "@bright:blue" +OUTLIER = {"outlier": True} +ROOM_ID = "!room:blue" + + +class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + + @defer.inlineCallbacks + def test_room_name_and_aliases(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.persist(type="m.room.name", key="", name="name1") + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#1:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) + ) + + # Set the room name. + yield self.persist(type="m.room.name", key="", name="name2") + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) + ) + + # Set the room aliases. + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#2:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) + ) + + # Leave and join the room clobbering the state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), (None, []) + ) + + @defer.inlineCallbacks + def test_room_members(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Join the room. + join = yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + + # Leave the room. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Add some other user to the room. + join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) + + # Join the room clobbering the state. + # This should remove any evidence of the other user being in the room. + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + yield self.check("get_rooms_for_user", (USER_ID_2,), []) + + event_id = 0 + + @defer.inlineCallbacks + def persist( + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content + ): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ + if depth is None: + depth = self.event_id + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + event = FrozenEvent(event_dict, internal_metadata_dict=internal) + + self.event_id += 1 + + context = EventContext(current_state=state) + + ordering = None + if backfill: + yield self.master_store.persist_events( + [(event, context)], backfilled=True + ) + else: + ordering, _ = yield self.master_store.persist_event( + event, context, current_state=reset_state + ) + + if ordering: + event.internal_metadata.stream_ordering = ordering + + defer.returnValue(event) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 2e33beb07..afbefb2e2 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), ) self.datastore = SQLBaseStore(hs) diff --git a/tests/utils.py b/tests/utils.py index 52405502e..c179df31e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), **kargs ) @@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): return conn def create_engine(self): - return create_engine(self.config) + return create_engine(self.config.database_config) class MemoryDataStore(object):