From e2e1d90a5e4030616a3de242cde26c0cfff4a6b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Mar 2022 12:49:54 +0000 Subject: [PATCH] Faster joins: persist to database (#12012) When we get a partial_state response from send_join, store information in the database about it: * store a record about the room as a whole having partial state, and stash the list of member servers too. * flag the join event itself as having partial state * also, for any new events whose prev-events are partial-stated, note that they will *also* be partial-stated. We don't yet make any attempt to interpret this data, so API calls (and a bunch of other things) are just going to get incorrect data. --- changelog.d/12012.misc | 1 + synapse/events/snapshot.py | 9 +++ synapse/handlers/federation.py | 11 ++- synapse/handlers/federation_event.py | 13 +++- synapse/handlers/message.py | 2 + synapse/state/__init__.py | 31 +++++++- synapse/storage/databases/main/events.py | 25 +++++++ .../storage/databases/main/events_worker.py | 28 ++++++++ synapse/storage/databases/main/room.py | 37 ++++++++++ .../main/delta/68/04partial_state_rooms.sql | 41 +++++++++++ .../68/05partial_state_rooms_triggers.py | 72 +++++++++++++++++++ tests/test_state.py | 59 ++++++++------- 12 files changed, 297 insertions(+), 32 deletions(-) create mode 100644 changelog.d/12012.misc create mode 100644 synapse/storage/schema/main/delta/68/04partial_state_rooms.sql create mode 100644 synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py diff --git a/changelog.d/12012.misc b/changelog.d/12012.misc new file mode 100644 index 000000000..a473f41e7 --- /dev/null +++ b/changelog.d/12012.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5833fee25..46042b2bf 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -101,6 +101,9 @@ class EventContext: As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. + + partial_state: if True, we may be storing this event with a temporary, + incomplete state. 
""" rejected: Union[bool, str] = False @@ -113,12 +116,15 @@ class EventContext: _current_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None + partial_state: bool = False + @staticmethod def with_state( state_group: Optional[int], state_group_before_event: Optional[int], current_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]], + partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": @@ -129,6 +135,7 @@ class EventContext: state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, + partial_state=partial_state, ) @staticmethod @@ -170,6 +177,7 @@ class EventContext: "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, + "partial_state": self.partial_state, } @staticmethod @@ -196,6 +204,7 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], + partial_state=input.get("partial_state", False), ) app_service_id = input["app_service_id"] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c055c26ec..eb03a5acc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -519,8 +519,17 @@ class FederationHandler: state_events=state, ) + if ret.partial_state: + await self.store.store_partial_state_room(room_id, ret.servers_in_room) + max_stream_id = await self._federation_event_handler.process_remote_join( - origin, room_id, auth_chain, state, event, room_version_obj + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, ) # We wait here until this instance has seen the events come down diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 09d0de1ea..4bd87709f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -397,6 +397,7 @@ class FederationEventHandler: state: List[EventBase], event: EventBase, room_version: RoomVersion, + partial_state: bool, ) -> int: """Persists the events returned by a send_join @@ -412,6 +413,7 @@ class FederationEventHandler: event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + partial_state: True if the state omits non-critical membership events Returns: The stream ID after which all events have been persisted. @@ -453,10 +455,14 @@ class FederationEventHandler: ) # and now persist the join event itself. 
- logger.info("Peristing join-via-remote %s", event) + logger.info( + "Peristing join-via-remote %s (partial_state: %s)", event, partial_state + ) with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( - event, old_state=state + event, + old_state=state, + partial_state=partial_state, ) context = await self._check_event_auth(origin, event, context) @@ -698,6 +704,8 @@ class FederationEventHandler: try: state = await self._resolve_state_at_missing_prevs(origin, event) + # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does + # not return partial state await self._process_received_pdu( origin, event, state=state, backfilled=backfilled ) @@ -1791,6 +1799,7 @@ class FederationEventHandler: prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, + partial_state=context.partial_state, ) async def _run_push_actions_and_persist_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ce1fa3c78..61cb133ef 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -992,6 +992,8 @@ class EventCreationHandler: and full_state_ids_at_event and builder.internal_metadata.is_historical() ): + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. old_state = await self.store.get_events_as_list(full_state_ids_at_event) context = await self.state.compute_event_context(event, old_state=old_state) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fcc24ad12..6babd5963 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -258,7 +258,10 @@ class StateHandler: return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ class StateHandler: calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ class StateHandler: else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) 
+ prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 23fa089bc..ca2a9ba9d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2145,6 +2145,14 @@ class PersistEventsStore: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): + # double-check that we don't have any events that claim to be outliers + # *and* have partial state (which is meaningless: we should have no + # state at all for an outlier) + if context.partial_state: + raise ValueError( + "Outlier event %s claims to have partial state" % (event.event_id,) + ) + continue # if the event was rejected, just give it the same state as its @@ -2155,6 +2163,23 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group + # if we have partial state for these events, record the fact. (This happens + # here rather than in _store_event_txn because it also needs to happen when + # we de-outlier an event.)
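Taken together with the compute_event_context change above, the rule is simple: an event has partial state if it comes from a partial-state join, or if any of its prev-events has partial state. A minimal sketch of that check, assuming only the store method added by this patch (illustrative, not the actual implementation):

    async def _has_partial_state(store, event, joined_with_partial_state: bool) -> bool:
        if joined_with_partial_state:
            # the join event itself is flagged explicitly by process_remote_join
            return True
        # otherwise the flag is inherited: one partial-state prev-event is enough
        partial = await store.get_partial_state_events(event.prev_event_ids())
        return any(partial.values())

The hunk below then records the flag in partial_state_events for each such event: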
+ self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 2a255d103..26784f755 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore): "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, ) + + @cachedList("is_partial_state_event", list_name="event_ids") + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + """Checks which of the given events have partial state""" + result = await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ) + # convert the result to a dict, to make @cachedList work + partial = {r["event_id"] for r in result} + return {e_id: e_id in partial for e_id in event_ids} + + @cached() + async def is_partial_state_event(self, event_id: str) -> bool: + """Checks if the given event has partial state""" + result = await self.db_pool.simple_select_one_onecol( + table="partial_state_events", + keyvalues={"event_id": event_id}, + retcol="1", + allow_none=True, + desc="is_partial_state_event", + ) + return result is not None diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0416df64c..94068940b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Collection, Dict, List, Optional, @@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): lock=False, ) + async def store_partial_state_room( + self, + room_id: str, + servers: Collection[str], + ) -> None: + """Mark the given room as containing events with partial state + + Args: + room_id: the ID of the room + servers: other servers known to be in the room + """ + await self.db_pool.runInteraction( + "store_partial_state_room", + self._store_partial_state_room_txn, + room_id, + servers, + ) + + @staticmethod + def _store_partial_state_room_txn( + txn: LoggingTransaction, room_id: str, servers: Collection[str] + ) -> None: + DatabasePool.simple_insert_txn( + txn, + table="partial_state_rooms", + values={ + "room_id": room_id, + }, + ) + DatabasePool.simple_insert_many_txn( + txn, + table="partial_state_rooms_servers", + keys=("room_id", "server_name"), + values=((room_id, s) for s in servers), + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql new file mode 100644 index 000000000..815c0cc39 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql @@ -0,0 +1,41 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- rooms which we have done a partial-state-style join to +CREATE TABLE IF NOT EXISTS partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); + +-- a list of remote servers we believe are in the room +CREATE TABLE IF NOT EXISTS partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); + +-- a list of events with partial state. We can't store this in the `events` table +-- itself, because `events` is meant to be append-only. +CREATE TABLE IF NOT EXISTS partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); + +CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx + ON partial_state_events (room_id); + + diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py new file mode 100644 index 000000000..a2ec4fc26 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -0,0 +1,72 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the partial_state_events table, to ensure that its +denormalised room_id stays consistent with the room_id recorded for the event in `events`. + +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # complain if the room_id in partial_state_events doesn't match + # that in `events`. We already have a fk constraint which ensures that the event + # exists in `events`, so all we have to do is raise if there is a row with a + # matching event_id but not a matching room_id.
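The invariant being enforced can be written out as a plain check, roughly as below (an illustrative sketch: `txn` is assumed to be a DB-API cursor, and the placeholder style differs between engines):

    def _check_partial_state_event_row(txn, event_id: str, room_id: str) -> None:
        txn.execute("SELECT room_id FROM events WHERE event_id = ?", (event_id,))
        row = txn.fetchone()
        # the fk constraint already guarantees the event exists; all that matters
        # here is that the denormalised room_id agrees with the canonical one
        if row is not None and row[0] != room_id:
            raise RuntimeError("Incorrect room_id in partial_state_events")

The engine-specific trigger definitions follow.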
+ if isinstance(database_engine, Sqlite3Engine): + cur.execute( + """ + CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + cur.execute( + """ + CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events + FOR EACH ROW + EXECUTE PROCEDURE check_partial_state_events() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/tests/test_state.py b/tests/test_state.py index 90800421f..e4baa6913 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Collection, Dict, List, Optional from unittest.mock import Mock from twisted.internet import defer @@ -70,7 +70,7 @@ def create_event( return event -class StateGroupStore: +class _DummyStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -105,6 +105,11 @@ class StateGroupStore: if e_id in self._event_id_to_event } + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + return {e: False for e in event_ids} + async def get_state_group_delta(self, name): return None, None @@ -157,8 +162,8 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): - self.store = StateGroupStore() - storage = Mock(main=self.store, state=self.store) + self.dummy_store = _DummyStore() + storage = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", @@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastores.return_value = Mock(main=self.store) + hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) @@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store: dict[str, EventContext] = {} @@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] @@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase): context = 
yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between B and C @@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between C and D because bans win over other @@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # B ends up winning the resolution between B and C because power levels @@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - self.store.register_events(old_state_1) - self.store.register_events(old_state_2) + self.dummy_store.register_events(old_state_1) + self.dummy_store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] - store = StateGroupStore() + store = 
_DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase): self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_1, event.room_id, None, @@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_1}, ) ) - self.store.register_event_id_state_group(prev_event_id_1, sg1) + self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_2, event.room_id, None, @@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_2}, ) ) - self.store.register_event_id_state_group(prev_event_id_2, sg2) + self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2) result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result
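Since compute_event_context now always consults get_partial_state_events when resolving state from prev-events, _DummyStore has to stub that method out (everything reported as non-partial). A later test wanting to exercise the propagation path could swap in a variant along these lines (a hypothetical stub, not part of the patch):

    class _PartialStateDummyStore(_DummyStore):
        """A dummy store that marks a chosen set of events as having partial state."""

        def __init__(self, partial_event_ids: Collection[str]):
            super().__init__()
            self._partial_event_ids = set(partial_event_ids)

        async def get_partial_state_events(
            self, event_ids: Collection[str]
        ) -> Dict[str, bool]:
            # mark only the chosen events as partial
            return {e: e in self._partial_event_ids for e in event_ids}

With that in place, compute_event_context should return a context with partial_state=True for any event whose prev-events include one of partial_event_ids.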