diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d59685241..cdc078cf1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -775,6 +775,9 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 01a4f8588..23b1650e4 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -72,6 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def _invalidate_caches_for_devices(self, token, rows): for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) self.get_cached_devices_for_user.invalidate((row.entity,)) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 06e1d9f03..4c19c02bb 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple from six import iteritems @@ -31,7 +32,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import Database, LoggingTransaction from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -574,10 +575,12 @@ class DeviceWorkerStore(SQLBaseStore): else: return set() - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. + async def get_all_device_list_changes_for_remotes( + self, from_key: int, to_key: int + ) -> List[Tuple[int, str]]: + """Return a list of `(stream_id, entity)` which is the combined list of + changes to devices and which destinations need to be poked. Entity is + either a user ID (starting with '@') or a remote destination. """ # This query Does The Right Thing where it'll correctly apply the @@ -591,7 +594,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? """ - return self.db.execute( + return await self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -1018,11 +1021,19 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] - def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids): + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) + min_stream_id = stream_ids[0] + # Delete older entries in the table, as we really only care about # when the latest change happened. txn.executemany( @@ -1030,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_ids[0]) for device_id in device_ids], + [(user_id, device_id, min_stream_id) for device_id in device_ids], ) self.db.simple_insert_many_txn(