diff --git a/changelog.d/12321.misc b/changelog.d/12321.misc new file mode 100644 index 000000000..200e7c44f --- /dev/null +++ b/changelog.d/12321.misc @@ -0,0 +1 @@ +Add ground work for speeding up device list updates for users in large numbers of rooms. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index c38666da1..6324df883 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = { "users": ["shadow_banned"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], + "device_lists_changes_in_room": ["converted_to_destinations"], } diff --git a/synapse/config/server.py b/synapse/config/server.py index 0f90302c9..b3a9e5075 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -680,6 +680,14 @@ class ServerConfig(Config): config.get("use_account_validity_in_account_status") or False ) + # This is a temporary option that enables fully using the new + # `device_lists_changes_in_room` without the backwards compat code. This + # is primarily for testing. If enabled the server should *not* be + # downgraded, as it may lead to missing device list updates. + self.use_new_device_lists_changes_in_room = ( + config.get("use_new_device_lists_changes_in_room") or False + ) + self.rooms_to_exclude_from_sync: List[str] = ( config.get("exclude_rooms_from_sync") or [] ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d5ccaa0c3..c710c02cf 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -37,7 +37,10 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import ( JsonDict, StreamToken, @@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + # Whether `_handle_new_device_update_async` is currently processing. + self._handle_new_device_update_is_processing = False + + # If a new device update may have happened while the loop was + # processing. + self._handle_new_device_update_new_data = False + + # On start up check if there are any updates pending. + hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + + # Used to decide if we calculate outbound pokes up front or not. By + # default we do to allow safely downgrading Synapse. + self.use_new_device_lists_changes_in_room = ( + hs.config.server.use_new_device_lists_changes_in_room + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler): # No changes to notify about, so this is a no-op. return - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id - ) + room_ids = await self.store.get_rooms_for_user(user_id) - hosts: Set[str] = set() - if self.hs.is_mine_id(user_id): - hosts.update(get_domain_from_id(u) for u in users_who_share_room) - hosts.discard(self.server_name) + hosts: Optional[Set[str]] = None + if not self.use_new_device_lists_changes_in_room: + hosts = set() - set_tag("target_hosts", hosts) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + joined_users = await self.store.get_users_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in joined_users) + + set_tag("target_hosts", hosts) + + hosts.discard(self.server_name) position = await self.store.add_device_change_to_streams( - user_id, device_ids, list(hosts) + user_id, + device_ids, + hosts=hosts, + room_ids=room_ids, ) if not position: @@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler): # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - users_to_notify = users_who_share_room.union({user_id}) + self.notifier.on_new_event( + "device_list_key", position, users={user_id}, rooms=room_ids + ) - self.notifier.on_new_event("device_list_key", position, users=users_to_notify) + # We may need to do some processing asynchronously. + self._handle_new_device_update_async() if hosts: logger.info( @@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler): return {"success": True} + @wrap_as_background_process("_handle_new_device_update_async") + async def _handle_new_device_update_async(self) -> None: + """Called when we have a new local device list update that we need to + send out over federation. + + This happens in the background so as not to block the original request + that generated the device update. + """ + if self._handle_new_device_update_is_processing: + self._handle_new_device_update_new_data = True + return + + self._handle_new_device_update_is_processing = True + + # The stream ID we processed previous iteration (if any), and the set of + # hosts we've already poked about for this update. This is so that we + # don't poke the same remote server about the same update repeatedly. + current_stream_id = None + hosts_already_sent_to: Set[str] = set() + + try: + while True: + self._handle_new_device_update_new_data = False + rows = await self.store.get_uncoverted_outbound_room_pokes() + if not rows: + # If the DB returned nothing then there is nothing left to + # do, *unless* a new device list update happened during the + # DB query. + if self._handle_new_device_update_new_data: + continue + else: + return + + for user_id, device_id, room_id, stream_id, opentracing_context in rows: + joined_user_ids = await self.store.get_users_in_room(room_id) + hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts.discard(self.server_name) + + # Check if we've already sent this update to some hosts + if current_stream_id == stream_id: + hosts -= hosts_already_sent_to + + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + stream_id=stream_id, + hosts=hosts, + context=opentracing_context, + ) + + # Notify replication that we've updated the device list stream. + self.notifier.notify_replication() + + if hosts: + logger.info( + "Sending device list update notif for %r to: %r", + user_id, + hosts, + ) + for host in hosts: + self.federation_sender.send_device_messages( + host, immediate=False + ) + log_kv( + {"message": "sent device update to host", "host": host} + ) + + if current_stream_id != stream_id: + # Clear the set of hosts we've already sent to as we're + # processing a new update. + hosts_already_sent_to.clear() + + hosts_already_sent_to.update(hosts) + current_stream_id = stream_id + + finally: + self._handle_new_device_update_is_processing = False + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 0ffd34f1d..f040e33bf 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) device_list_max = self._device_list_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 1ea0b2aa6..cdbe3872f 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -146,6 +146,7 @@ class DataStore( extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f08f7834d..07eea4b3d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? """ @@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + self, + user_id: str, + device_ids: Collection[str], + hosts: Optional[Collection[str]], + room_ids: Collection[str], ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: The ID of the user whose device changed. device_ids: The IDs of any changed devices. If empty, this function will return None. - hosts: The remote destinations that should be notified of the change. + hosts: The remote destinations that should be notified of the change. If + None then the set of hosts have *not* been calculated, and will be + calculated later by a background task. + room_ids: The rooms that the user is in Returns: The maximum stream ID of device list updates that were added to the database, or @@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return None - async with self._device_list_id_gen.get_next_mult( - len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_change_to_stream", - self._add_device_change_to_stream_txn, + context = get_active_span_text_map() + + def add_device_changes_txn( + txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes + ): + self._add_device_change_to_stream_txn( + txn, user_id, device_ids, - stream_ids, + stream_ids_for_device_change, ) - if not hosts: - return stream_ids[-1] + self._add_device_outbound_room_poke_txn( + txn, + user_id, + device_ids, + room_ids, + stream_ids_for_device_change, + context, + hosts_have_been_calculated=hosts is not None, + ) - context = get_active_span_text_map() - async with self._device_list_id_gen.get_next_mult( - len(hosts) * len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_outbound_poke_to_stream", - self._add_device_outbound_poke_to_stream_txn, + # If the set of hosts to send to has not been calculated yet (and so + # `hosts` is None) or there are no `hosts` to send to, then skip + # trying to persist them to the DB. + if not hosts: + return + + self._add_device_outbound_poke_to_stream_txn( + txn, user_id, device_ids, hosts, - stream_ids, + stream_ids_for_outbound_pokes, context, ) + # `device_lists_stream` wants a stream ID per device update. + num_stream_ids = len(device_ids) + + if hosts: + # `device_lists_outbound_pokes` wants a different stream ID for + # each row, which is a row per host per device update. + num_stream_ids += len(hosts) * len(device_ids) + + async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids: + stream_ids_for_device_change = stream_ids[: len(device_ids)] + stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :] + + await self.db_pool.runInteraction( + "add_device_change_to_stream", + add_device_changes_txn, + stream_ids_for_device_change, + stream_ids_for_outbound_pokes, + ) + return stream_ids[-1] def _add_device_change_to_stream_txn( @@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], hosts: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: for host in hosts: @@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) now = self._clock.time_msec() - next_stream_id = iter(stream_ids) + stream_id_iterator = iter(stream_ids) + encoded_context = json_encoder.encode(context) self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ ( destination, - next(next_stream_id), + next(stream_id_iterator), user_id, device_id, False, now, - json_encoder.encode(context) - if whitelisted_homeserver(destination) - else "{}", + encoded_context if whitelisted_homeserver(destination) else "{}", ) for destination in hosts for device_id in device_ids ], ) + + def _add_device_outbound_room_poke_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Iterable[str], + room_ids: Collection[str], + stream_ids: List[str], + context: Dict[str, str], + hosts_have_been_calculated: bool, + ) -> None: + """Record the user in the room has updated their device. + + Args: + hosts_have_been_calculated: True if `device_lists_outbound_pokes` + has been updated already with the updates. + """ + + # We only need to convert to outbound pokes if they are our user. + converted_to_destinations = ( + hosts_have_been_calculated or not self.hs.is_mine_id(user_id) + ) + + encoded_context = json_encoder.encode(context) + + # The `device_lists_changes_in_room.stream_id` column matches the + # corresponding `stream_id` of the update in the `device_lists_stream` + # table, i.e. all rows persisted for the same device update will have + # the same `stream_id` (but different room IDs). + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_changes_in_room", + keys=( + "user_id", + "device_id", + "room_id", + "stream_id", + "converted_to_destinations", + "opentracing_context", + ), + values=[ + ( + user_id, + device_id, + room_id, + stream_id, + converted_to_destinations, + encoded_context, + ) + for room_id in room_ids + for device_id, stream_id in zip(device_ids, stream_ids) + ], + ) + + async def get_uncoverted_outbound_room_pokes( + self, limit: int = 10 + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: + """Get device list changes by room that have not yet been handled and + written to `device_lists_outbound_pokes`. + + Returns: + A list of user ID, device ID, room ID, stream ID and optional opentracing context. + """ + + sql = """ + SELECT user_id, device_id, room_id, stream_id, opentracing_context + FROM device_lists_changes_in_room + WHERE NOT converted_to_destinations + ORDER BY stream_id + LIMIT ? + """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index ea900e0f3..151f2aa9b 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68: new events. Changes in SCHEMA_VERSION = 69: + - We now write to `device_lists_changes_in_room` table. - Use sequence to generate future `application_services_txns.txn_id`s """ diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql new file mode 100644 index 000000000..b5b1782b2 --- /dev/null +++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql @@ -0,0 +1,38 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_lists_changes_in_room ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- This initially matches `device_lists_stream.stream_id`. Note that we + -- delete older values from `device_lists_stream`, so we can't use a foreign + -- constraint here. + -- + -- The table will contain rows with the same `stream_id` but different + -- `room_id`, as for each device update we store a row per room the user is + -- joined to. Therefore `(stream_id, room_id)` gives a unique index. + stream_id BIGINT NOT NULL, + + -- We have a background process which goes through this table and converts + -- entries into rows in `device_lists_outbound_pokes`. Once we have processed + -- a row, we mark it as such by setting `converted_to_destinations=TRUE`. + converted_to_destinations BOOLEAN NOT NULL, + opentracing_context TEXT +); + +CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id); +CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations; diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index e90592855..a6e91956a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -14,6 +14,7 @@ from typing import Optional from unittest.mock import Mock +from parameterized import parameterized_class from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ) +@parameterized_class( + [ + {"enable_room_poke_code_path": False}, + {"enable_room_poke_code_path": True}, + ] +) class FederationSenderDevicesTestCases(HomeserverTestCase): servlets = [ admin.register_servlets, @@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def default_config(self): c = super().default_config() c["send_federation"] = True + c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path return c def prepare(self, reactor, clock, hs): - # stub out get_users_who_share_room_with_user so that it claims that - # `@user2:host2` is in the room - def get_users_who_share_room_with_user(user_id): + # stub out `get_rooms_for_user` and `get_users_in_room` so that the + # server thinks the user shares a room with `@user2:host2` + def get_rooms_for_user(user_id): + return defer.succeed({"!room:host1"}) + + hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user + + def get_users_in_room(room_id): return defer.succeed({"@user2:host2"}) - hs.get_datastores().main.get_users_who_share_room_with_user = ( - get_users_who_share_room_with_user - ) + hs.get_datastores().main.get_users_in_room = get_users_in_room # whenever send_transaction is called, record the edu data self.edus = [] diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 21ffc5a90..d1227dd4a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add two device updates with sequential `stream_id`s self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get all device updates ever meant for this remote @@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase): "device_id5", ] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get device updates meant for this remote @@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add some more device updates to ensure it still resumes properly device_ids = ["device_id6", "device_id7"] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get the next batch of device updates @@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase): self.get_success( self.store.add_device_change_to_streams( - "@user_id:test", device_ids, ["somehost"] + "@user_id:test", device_ids, ["somehost"], ["!some:room"] ) )