diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc new file mode 100644 index 000000000..346741d98 --- /dev/null +++ b/changelog.d/9124.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 7128dc174..e46e44ba5 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -16,6 +16,8 @@ import logging from typing import Dict, List, Optional, Tuple +import attr + from synapse.api.constants import EventContentFields from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -28,6 +30,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True) +class _CalculateChainCover: + """Return value for _calculate_chain_cover_txn. + """ + + # The last room_id/depth/stream processed. + room_id = attr.ib(type=str) + depth = attr.ib(type=int) + stream = attr.ib(type=int) + + # Number of rows processed + processed_count = attr.ib(type=int) + + # Map from room_id to last depth/stream processed for each room that we have + # processed all events for (i.e. the rooms we can flip the + # `has_auth_chain_index` for) + finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + + class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): current_room_id = progress.get("current_room_id", "") - # Have we finished processing the current room. - finished = progress.get("finished", True) - # Where we've processed up to in the room, defaults to the start of the # room. last_depth = progress.get("last_depth", -1) last_stream = progress.get("last_stream", -1) - # Have we set the `has_auth_chain_index` for the room yet. - has_set_room_has_chain_index = progress.get( - "has_set_room_has_chain_index", False + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + current_room_id, + last_depth, + last_stream, + batch_size, + single_room=False, ) - if finished: - # If we've finished with the previous room (or its our first - # iteration) we move on to the next room. + finished = result.processed_count == 0 - def _get_next_room(txn: Cursor) -> Optional[str]: - sql = """ - SELECT room_id FROM rooms - WHERE room_id > ? - AND ( - NOT has_auth_chain_index - OR has_auth_chain_index IS NULL - ) - ORDER BY room_id - LIMIT 1 - """ - txn.execute(sql, (current_room_id,)) - row = txn.fetchone() - if row: - return row[0] + total_rows_processed = result.processed_count + current_room_id = result.room_id + last_depth = result.depth + last_stream = result.stream - return None - - current_room_id = await self.db_pool.runInteraction( - "_chain_cover_index", _get_next_room - ) - if not current_room_id: - await self.db_pool.updates._end_background_update("chain_cover") - return 0 - - logger.debug("Adding chain cover to %s", current_room_id) - - def _calculate_auth_chain( - txn: Cursor, last_depth: int, last_stream: int - ) -> Tuple[int, int, int]: - # Get the next set of events in the room (that we haven't already - # computed chain cover for). We do this in topological order. - - # We want to do a `(topological_ordering, stream_ordering) > (?,?)` - # comparison, but that is not supported on older SQLite versions - tuple_clause, tuple_args = make_tuple_comparison_clause( - self.database_engine, - [ - ("topological_ordering", last_depth), - ("stream_ordering", last_stream), - ], - ) - - sql = """ - SELECT - event_id, state_events.type, state_events.state_key, - topological_ordering, stream_ordering - FROM events - INNER JOIN state_events USING (event_id) - LEFT JOIN event_auth_chains USING (event_id) - LEFT JOIN event_auth_chain_to_calculate USING (event_id) - WHERE events.room_id = ? - AND event_auth_chains.event_id IS NULL - AND event_auth_chain_to_calculate.event_id IS NULL - AND %(tuple_cmp)s - ORDER BY topological_ordering, stream_ordering - LIMIT ? - """ % { - "tuple_cmp": tuple_clause, - } - - args = [current_room_id] - args.extend(tuple_args) - args.append(batch_size) - - txn.execute(sql, args) - rows = txn.fetchall() - - # Put the results in the necessary format for - # `_add_chain_cover_index` - event_to_room_id = {row[0]: current_room_id for row in rows} - event_to_types = {row[0]: (row[1], row[2]) for row in rows} - - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - - count = len(rows) - - # We also need to fetch the auth events for them. - auth_events = self.db_pool.simple_select_many_txn( - txn, - table="event_auth", - column="event_id", - iterable=event_to_room_id, - keyvalues={}, - retcols=("event_id", "auth_id"), - ) - - event_to_auth_chain = {} # type: Dict[str, List[str]] - for row in auth_events: - event_to_auth_chain.setdefault(row["event_id"], []).append( - row["auth_id"] - ) - - # Calculate and persist the chain cover index for this set of events. - # - # Annoyingly we need to gut wrench into the persit event store so that - # we can reuse the function to calculate the chain cover for rooms. - PersistEventsStore._add_chain_cover_index( - txn, - self.db_pool, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) - - return new_last_depth, new_last_stream, count - - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream - ) - - total_rows_processed = count - - if count < batch_size and not has_set_room_has_chain_index: + for room_id, (depth, stream) in result.finished_room_map.items(): # If we've done all the events in the room we flip the # `has_auth_chain_index` in the DB. Note that its possible for # further events to be persisted between the above and setting the @@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.simple_update( table="rooms", - keyvalues={"room_id": current_room_id}, + keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": True}, desc="_chain_cover_index", ) - has_set_room_has_chain_index = True # Handle any events that might have raced with us flipping the # bit above. - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + room_id, + depth, + stream, + batch_size=None, + single_room=True, ) - total_rows_processed += count + total_rows_processed += result.processed_count - # Note that at this point its technically possible that more events - # than our `batch_size` have been persisted without their chain - # cover, so we need to continue processing this room if the last - # count returned was equal to the `batch_size`. + if finished: + await self.db_pool.updates._end_background_update("chain_cover") + return total_rows_processed - if count < batch_size: - # We've finished calculating the index for this room, move on to the - # next room. - await self.db_pool.updates._background_update_progress( - "chain_cover", {"current_room_id": current_room_id, "finished": True}, - ) - else: - # We still have outstanding events to calculate the index for. - await self.db_pool.updates._background_update_progress( - "chain_cover", - { - "current_room_id": current_room_id, - "last_depth": last_depth, - "last_stream": last_stream, - "has_auth_chain_index": has_set_room_has_chain_index, - "finished": False, - }, - ) + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + }, + ) return total_rows_processed + + def _calculate_chain_cover_txn( + self, + txn: Cursor, + last_room_id: str, + last_depth: int, + last_stream: int, + batch_size: Optional[int], + single_room: bool, + ) -> _CalculateChainCover: + """Calculate the chain cover for `batch_size` events, ordered by + `(room_id, depth, stream)`. + + Args: + txn, + last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` + tuple to fetch results after. + batch_size: The maximum number of events to process. If None then + no limit. + single_room: Whether to calculate the index for just the given + room. + """ + + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. + + # We want to do a `(topological_ordering, stream_ordering) > (?,?)` + # comparison, but that is not supported on older SQLite versions + tuple_clause, tuple_args = make_tuple_comparison_clause( + self.database_engine, + [ + ("events.room_id", last_room_id), + ("topological_ordering", last_depth), + ("stream_ordering", last_stream), + ], + ) + + extra_clause = "" + if single_room: + extra_clause = "AND events.room_id = ?" + tuple_args.append(last_room_id) + + sql = """ + SELECT + event_id, state_events.type, state_events.state_key, + topological_ordering, stream_ordering, + events.room_id + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + LEFT JOIN event_auth_chain_to_calculate USING (event_id) + WHERE event_auth_chains.event_id IS NULL + AND event_auth_chain_to_calculate.event_id IS NULL + AND %(tuple_cmp)s + %(extra)s + ORDER BY events.room_id, topological_ordering, stream_ordering + %(limit)s + """ % { + "tuple_cmp": tuple_clause, + "limit": "LIMIT ?" if batch_size is not None else "", + "extra": extra_clause, + } + + if batch_size is not None: + tuple_args.append(batch_size) + + txn.execute(sql, tuple_args) + rows = txn.fetchall() + + # Put the results in the necessary format for + # `_add_chain_cover_index` + event_to_room_id = {row[0]: row[5] for row in rows} + event_to_types = {row[0]: (row[1], row[2]) for row in rows} + + # Calculate the new last position we've processed up to. + new_last_depth = rows[-1][3] if rows else last_depth # type: int + new_last_stream = rows[-1][4] if rows else last_stream # type: int + new_last_room_id = rows[-1][5] if rows else "" # type: str + + # Map from room_id to last depth/stream_ordering processed for the room, + # excluding the last room (which we're likely still processing). We also + # need to include the room passed in if it's not included in the result + # set (as we then know we've processed all events in said room). + # + # This is the set of rooms that we can now safely flip the + # `has_auth_chain_index` bit for. + finished_rooms = { + row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id + } + if last_room_id not in finished_rooms and last_room_id != new_last_room_id: + finished_rooms[last_room_id] = (last_depth, last_stream) + + count = len(rows) + + # We also need to fetch the auth events for them. + auth_events = self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ) + + event_to_auth_chain = {} # type: Dict[str, List[str]] + for row in auth_events: + event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) + + # Calculate and persist the chain cover index for this set of events. + # + # Annoyingly we need to gut wrench into the persit event store so that + # we can reuse the function to calculate the chain cover for rooms. + PersistEventsStore._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + return _CalculateChainCover( + room_id=new_last_room_id, + depth=new_last_depth, + stream=new_last_stream, + processed_count=count, + finished_room_map=finished_rooms, + ) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index ff67a7374..0c46ad595 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from twisted.trial import unittest @@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def test_background_update(self): - """Test that the background update to calculate auth chains for historic - rooms works correctly. + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.requester = create_requester(self.user_id) + + def _generate_room(self) -> Tuple[str, List[Set[str]]]: + """Insert a room without a chain cover index. """ - - # Create a room - user_id = self.register_user("foo", "pass") - token = self.login("foo", "pass") - room_id = self.helper.create_room_as(user_id, tok=token) - requester = create_requester(user_id) - - store = self.hs.get_datastore() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Mark the room as not having a chain cover index self.get_success( - store.db_pool.simple_update( + self.store.db_pool.simple_update( table="rooms", keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": False}, @@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Create a fork in the DAG with different events. event_handler = self.hs.get_event_creation_handler() - latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + latest_event_ids = self.get_success( + self.store.get_prev_events_for_room(room_id) + ) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state1 = list(self.get_success(context.get_current_state_ids()).values()) + state1 = set(self.get_success(context.get_current_state_ids()).values()) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state2 = list(self.get_success(context.get_current_state_ids()).values()) + state2 = set(self.get_success(context.get_current_state_ids()).values()) # Delete the chain cover info. @@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chain_links") - self.get_success(store.db_pool.runInteraction("test", _delete_tables)) + self.get_success(self.store.db_pool.runInteraction("test", _delete_tables)) + + return room_id, [state1, state2] + + def test_background_update_single_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() # Insert and run the background update. self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "chain_cover", "progress_json": "{}"}, ) ) # Ugh, have to reset this flag - store.db_pool.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - store.db_pool.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - store.db_pool.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Test that the `has_auth_chain_index` has been set - self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) # Test that calculating the auth chain difference using the newly # calculated chain cover works. self.get_success( - store.db_pool.runInteraction( + self.store.db_pool.runInteraction( "test", - store._get_auth_chain_difference_using_cover_index_txn, + self.store._get_auth_chain_difference_using_cover_index_txn, room_id, - [state1, state2], + states, ) ) + + def test_background_update_multiple_rooms(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + # Create a room + room_id1, states1 = self._generate_room() + room_id2, states2 = self._generate_room() + room_id3, states2 = self._generate_room() + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id1, + states1, + ) + ) + + def test_background_update_single_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create the rooms + room_id1, _ = self._generate_room() + room_id2, _ = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id1, event_type="m.test", body={"index": i}, tok=self.token + ) + + for i in range(0, 150): + self.helper.send_state( + room_id2, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))