From d8d0271977938d89585613e9a77537c33c4dc4a9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:39:27 +0100 Subject: [PATCH] Send device list updates to application services (MSC3202) - part 1 (#11881) Co-authored-by: Patrick Cloke --- changelog.d/11881.feature | 1 + synapse/appservice/__init__.py | 12 +- synapse/appservice/api.py | 10 +- synapse/appservice/scheduler.py | 53 ++++++- synapse/config/appservice.py | 1 + synapse/config/experimental.py | 5 +- synapse/handlers/appservice.py | 148 +++++++++++++++++- synapse/handlers/sync.py | 40 ++--- synapse/storage/databases/main/appservice.py | 14 +- synapse/storage/databases/main/devices.py | 48 ++++-- ...add_device_list_appservice_stream_type.sql | 23 +++ synapse/types.py | 25 +++ tests/appservice/test_scheduler.py | 54 +++++-- tests/handlers/test_appservice.py | 121 +++++++++++++- tests/storage/test_appservice.py | 17 +- 15 files changed, 490 insertions(+), 82 deletions(-) create mode 100644 changelog.d/11881.feature create mode 100644 synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql diff --git a/changelog.d/11881.feature b/changelog.d/11881.feature new file mode 100644 index 000000000..392294ffc --- /dev/null +++ b/changelog.d/11881.feature @@ -0,0 +1 @@ +Send device list changes to application services as specified by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/pull/3202), using unstable prefixes. The `msc3202_transaction_extensions` experimental homeserver config option must be enabled and `org.matrix.msc3202: true` must be present in the application service registration file for device list changes to be sent. The "left" field is currently always empty. \ No newline at end of file diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 07ec95f1d..d23d9221b 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +23,13 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.types import ( + DeviceListUpdates, + GroupID, + JsonDict, + UserID, + get_domain_from_id, +) from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -400,6 +407,7 @@ class AppServiceTransaction: to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ): self.service = service self.id = id @@ -408,6 +416,7 @@ class AppServiceTransaction: self.to_device_messages = to_device_messages self.one_time_key_counts = one_time_key_counts self.unused_fallback_keys = unused_fallback_keys + self.device_list_summary = device_list_summary async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -424,6 +433,7 @@ class AppServiceTransaction: to_device_messages=self.to_device_messages, one_time_key_counts=self.one_time_key_counts, unused_fallback_keys=self.unused_fallback_keys, + device_list_summary=self.device_list_summary, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 98fe35401..0cdbb04bf 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,7 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient): to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, ) -> bool: """ @@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: if one_time_key_counts: body[ @@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient): body[ "org.matrix.msc3202.device_unused_fallback_keys" ] = unused_fallback_keys + if device_list_summary: + body["org.matrix.msc3202.device_lists"] = { + "changed": list(device_list_summary.changed), + "left": list(device_list_summary.left), + } try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index a6084b9c3..3b49e6071 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -72,7 +72,7 @@ from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock if TYPE_CHECKING: @@ -122,6 +122,7 @@ class ApplicationServiceScheduler: events: Optional[Collection[EventBase]] = None, ephemeral: Optional[Collection[JsonDict]] = None, to_device_messages: Optional[Collection[JsonDict]] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Enqueue some data to be sent off to an application service. @@ -133,10 +134,18 @@ class ApplicationServiceScheduler: to_device_messages: The to-device messages to send. These differ from normal to-device messages sent to clients, as they have 'to_device_id' and 'to_user_id' fields. + device_list_summary: A summary of users that the application service either needs + to refresh the device lists of, or those that the application service need no + longer track the device lists of. """ # We purposefully allow this method to run with empty events/ephemeral # collections, so that callers do not need to check iterable size themselves. - if not events and not ephemeral and not to_device_messages: + if ( + not events + and not ephemeral + and not to_device_messages + and not device_list_summary + ): return if events: @@ -147,6 +156,10 @@ class ApplicationServiceScheduler: self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( to_device_messages ) + if device_list_summary: + self.queuer.queued_device_list_summaries.setdefault( + appservice.id, [] + ).append(device_list_summary) # Kick off a new application service transaction self.queuer.start_background_request(appservice) @@ -169,6 +182,8 @@ class _ServiceQueuer: self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # dict of {service_id: [to_device_message_json]} self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [device_list_summary]} + self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() @@ -212,7 +227,35 @@ class _ServiceQueuer: ] del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] - if not events and not ephemeral and not to_device_messages_to_send: + # Consolidate any pending device list summaries into a single, up-to-date + # summary. + # Note: this code assumes that in a single DeviceListUpdates, a user will + # never be in both "changed" and "left" sets. + device_list_summary = DeviceListUpdates() + for summary in self.queued_device_list_summaries.get(service.id, []): + # For every user in the incoming "changed" set: + # * Remove them from the existing "left" set if necessary + # (as we need to start tracking them again) + # * Add them to the existing "changed" set if necessary. + device_list_summary.left.difference_update(summary.changed) + device_list_summary.changed.update(summary.changed) + + # For every user in the incoming "left" set: + # * Remove them from the existing "changed" set if necessary + # (we no longer need to track them) + # * Add them to the existing "left" set if necessary. + device_list_summary.changed.difference_update(summary.left) + device_list_summary.left.update(summary.left) + self.queued_device_list_summaries.clear() + + if ( + not events + and not ephemeral + and not to_device_messages_to_send + # DeviceListUpdates is True if either the 'changed' or 'left' sets have + # at least one entry, otherwise False + and not device_list_summary + ): return one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None @@ -240,6 +283,7 @@ class _ServiceQueuer: to_device_messages_to_send, one_time_key_counts, unused_fallback_keys, + device_list_summary, ) except Exception: logger.exception("AS request failed") @@ -322,6 +366,7 @@ class _TransactionController: to_device_messages: Optional[List[JsonDict]] = None, one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -336,6 +381,7 @@ class _TransactionController: appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -345,6 +391,7 @@ class _TransactionController: to_device_messages=to_device_messages or [], one_time_key_counts=one_time_key_counts or {}, unused_fallback_keys=unused_fallback_keys or {}, + device_list_summary=device_list_summary or DeviceListUpdates(), ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 439bfe152..ada165f23 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -170,6 +170,7 @@ def _load_appservice( # When enabled, appservice transactions contain the following information: # - device One-Time Key counts # - device unused fallback key usage states + # - device list changes msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) if not isinstance(msc3202_transaction_extensions, bool): raise ValueError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 064db4487..d6bb1f752 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -59,8 +59,9 @@ class ExperimentalConfig(Config): "msc3202_device_masquerading", False ) - # Portion of MSC3202 related to transaction extensions: - # sending one-time key counts and fallback key usage to application services. + # The portion of MSC3202 related to transaction extensions: + # sending device list changes, one-time key counts and fallback key + # usage to application services. self.msc3202_transaction_extensions: bool = experimental.get( "msc3202_transaction_extensions", False ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bd913e524..316c4b677 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import ( + DeviceListUpdates, + JsonDict, + RoomAlias, + RoomStreamToken, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -58,6 +64,9 @@ class ApplicationServicesHandler: self._msc2409_to_device_messages_enabled = ( hs.config.experimental.msc2409_to_device_messages_enabled ) + self._msc3202_transaction_extensions_enabled = ( + hs.config.experimental.msc3202_transaction_extensions + ) self.current_max = 0 self.is_processing = False @@ -204,9 +213,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key" or - "to_device_key". Any other value for `stream_key` will cause this function - to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value for `stream_key` + will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -230,6 +239,7 @@ class ApplicationServicesHandler: "receipt_key", "presence_key", "to_device_key", + "device_list_key", ): return @@ -253,15 +263,37 @@ class ApplicationServicesHandler: ): return + # Ignore device lists if the feature flag is not enabled + if ( + stream_key == "device_list_key" + and not self._msc3202_transaction_extensions_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # # Note that whether these events are actually relevant to these appservices # is decided later on. + services = self.store.get_app_services() services = [ service - for service in self.store.get_app_services() - if service.supports_ephemeral + for service in services + # Different stream keys require different support booleans + if ( + stream_key + in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ) + and service.supports_ephemeral + ) + or ( + stream_key == "device_list_key" + and service.msc3202_transaction_extensions + ) ] if not services: # Bail out early if none of the target appservices have explicitly registered @@ -336,6 +368,20 @@ class ApplicationServicesHandler: service, "to_device", new_token ) + elif stream_key == "device_list_key": + device_list_summary = await self._get_device_list_summary( + service, new_token + ) + if device_list_summary: + self.scheduler.enqueue_for_appservice( + service, device_list_summary=device_list_summary + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "device_list", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -542,6 +588,96 @@ class ApplicationServicesHandler: return message_payload + async def _get_device_list_summary( + self, + appservice: ApplicationService, + new_key: int, + ) -> DeviceListUpdates: + """ + Retrieve a list of users who have changed their device lists. + + Args: + appservice: The application service to retrieve device list changes for. + new_key: The stream key of the device list change that triggered this method call. + + Returns: + A set of device list updates, comprised of users that the appservices needs to: + * resync the device list of, and + * stop tracking the device list of. + """ + # Fetch the last successfully processed device list update stream ID + # for this appservice. + from_key = await self.store.get_type_stream_id_for_appservice( + appservice, "device_list" + ) + + # Fetch the users who have modified their device list since then. + users_with_changed_device_lists = ( + await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + ) + + # Filter out any users the application service is not interested in + # + # For each user who changed their device list, we want to check whether this + # appservice would be interested in the change. + filtered_users_with_changed_device_lists = { + user_id + for user_id in users_with_changed_device_lists + if await self._is_appservice_interested_in_device_lists_of_user( + appservice, user_id + ) + } + + # Create a summary of "changed" and "left" users. + # TODO: Calculate "left" users. + device_list_summary = DeviceListUpdates( + changed=filtered_users_with_changed_device_lists + ) + + return device_list_summary + + async def _is_appservice_interested_in_device_lists_of_user( + self, + appservice: ApplicationService, + user_id: str, + ) -> bool: + """ + Returns whether a given application service is interested in the device list + updates of a given user. + + The application service is interested in the user's device list updates if any + of the following are true: + * The user is the appservice's sender localpart user. + * The user is in the appservice's user namespace. + * At least one member of one room that the user is a part of is in the + appservice's user namespace. + * The appservice is explicitly (via room ID or alias) interested in at + least one room that the user is in. + + Args: + appservice: The application service to gauge interest of. + user_id: The ID of the user whose device list interest is in question. + + Returns: + True if the application service is interested in the user's device lists, False + otherwise. + """ + # This method checks against both the sender localpart user as well as if the + # user is in the appservice's user namespace. + if appservice.is_interested_in_user(user_id): + return True + + # Determine whether any of the rooms the user is in justifies sending this + # device list update to the application service. + room_ids = await self.store.get_rooms_for_user(user_id) + for room_id in room_ids: + # This method covers checking room members for appservice interest as well as + # room ID and alias checks. + if await appservice.is_interested_in_room(room_id, self.store): + return True + + return False + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bceafca3b..303c38c74 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,17 +13,7 @@ # limitations under the License. import itertools import logging -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - FrozenSet, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + DeviceListUpdates, JsonDict, MutableStateMap, Requester, @@ -184,21 +175,6 @@ class GroupsSyncResult: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceLists: - """ - Attributes: - changed: List of user_ids whose devices may have changed - left: List of user_ids whose devices we no longer track - """ - - changed: Collection[str] - left: Collection[str] - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -240,7 +216,7 @@ class SyncResult: knocked: List[KnockedSyncResult] archived: List[ArchivedSyncResult] to_device: List[JsonDict] - device_lists: DeviceLists + device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] groups: Optional[GroupsSyncResult] @@ -1264,8 +1240,8 @@ class SyncHandler: newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], - ) -> DeviceLists: - """Generate the DeviceLists section of sync + ) -> DeviceListUpdates: + """Generate the DeviceListUpdates section of sync Args: sync_result_builder @@ -1383,9 +1359,11 @@ class SyncHandler: if any(e.room_id in joined_rooms for e in entries): newly_left_users.discard(user_id) - return DeviceLists(changed=users_that_have_changed, left=newly_left_users) + return DeviceListUpdates( + changed=users_that_have_changed, left=newly_left_users + ) else: - return DeviceLists(changed=[], left=[]) + return DeviceListUpdates() async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index abea4383c..55e1ab099 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. @@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. return AppServiceTransaction( service=service, @@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore( to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence", "to_device"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore( async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence", "to_device"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b7..f08f7834d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + if user_ids is None: + # Get set of all users that have had device list changes since 'from_key' + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + else: + # The same as above, but filter results to only those users in 'user_ids' + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) - if not to_check: + if not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + stream_id_where_clause = "stream_id > ?" + sql_args = [from_key] + + if to_key: + stream_id_where_clause += " AND stream_id <= ?" + sql_args.append(to_key) + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? + WHERE {stream_id_where_clause} AND """ - for chunk in batch_iter(to_check, 100): + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + txn.execute(sql + clause, sql_args + args) changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql new file mode 100644 index 000000000..7590e34b9 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what device list changes stream id that this application +-- service has been caught up to. + +-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible +-- state is useful for determining if we've ever sent traffic for a stream type +-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for +-- one way this can be used. +ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/types.py b/synapse/types.py index 5ce2a5b0a..500597b3a 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -25,6 +25,7 @@ from typing import ( Match, MutableMapping, Optional, + Set, Tuple, Type, TypeVar, @@ -748,6 +749,30 @@ class ReadReceipt: data: JsonDict +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + def get_verify_key_from_cross_signing_key(key_info): """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 1cbb05935..0b22afdc7 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,6 +24,7 @@ from synapse.appservice.scheduler import ( ) from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed @@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made self.assertEqual(1, self.recoverer.recover.call_count) # and invoked @@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_once_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( - service, [event2, event3], [], [], None, None + service, [event2, event3], [], [], None, None, DeviceListUpdates() ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv1, [srv_1_event], [], [], None, None, DeviceListUpdates() + ) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event], [], [], None, None, DeviceListUpdates() + ) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Expect the first event to be sent immediately. self.txn_ctrl.send.assert_called_with( - service, [event_list[0]], [], [], None, None + service, [event_list[0]], [], [], None, None, DeviceListUpdates() ) srv_1_defer.callback(service) # Then send the next 100 events self.txn_ctrl.send.assert_called_with( - service, event_list[1:101], [], [], None, None + service, event_list[1:101], [], [], None, None, DeviceListUpdates() ) srv_2_defer.callback(service) # Then the final 99 events self.txn_ctrl.send.assert_called_with( - service, event_list[101:], [], [], None, None + service, event_list[101:], [], [], None, None, DeviceListUpdates() ) self.assertEqual(3, self.txn_ctrl.send.call_count) @@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_multiple_ephemeral_no_queue(self): @@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_single_ephemeral_with_queue(self): @@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_1, [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [], None, None + service, + [], + event_list_2 + event_list_3, + [], + None, + None, + DeviceListUpdates(), ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], first_chunk, [], None, None + service, [], first_chunk, [], None, None, DeviceListUpdates() ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], second_chunk, [], None, None, DeviceListUpdates() + ) self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index cead9f90d..8c72cf6b3 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -15,6 +15,8 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock +from parameterized import parameterized + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): to_device_messages, _otks, _fbks, + _device_list_summary, ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent @@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): return appservice +class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends device list updates to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Allow us to modify cached feature flags mid-test + self.as_handler = hs.get_application_service_handler() + + # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that + # will be sent over the wire + self.put_json = simple_async_mock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) + + # Test across a variety of configuration values + @parameterized.expand( + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ] + ) + def test_application_service_receives_device_list_updates( + self, + experimental_feature_enabled: bool, + as_supports_txn_extensions: bool, + as_should_receive_device_list_updates: bool, + ): + """ + Tests that an application service receives notice of changed device + lists for a user, when a user changes their device lists. + + Arguments above are populated by parameterized. + + Args: + as_should_receive_device_list_updates: Whether we expect the AS to receive the + device list changes. + experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental + feature is enabled. This feature must be enabled for device lists to ASs to work. + as_supports_txn_extensions: Whether the application service has explicitly registered + to receive information defined by MSC3202 - which includes device list changes. + """ + # Change whether the experimental feature is enabled or disabled before making + # device list changes + self.as_handler._msc3202_transaction_extensions_enabled = ( + experimental_feature_enabled + ) + + # Create an appservice that is interested in "local_user" + appservice = ApplicationService( + token=random_string(10), + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@local_user:.+", + "exclusive": False, + } + ], + }, + supports_ephemeral=True, + msc3202_transaction_extensions=as_supports_txn_extensions, + # Must be set for Synapse to try pushing data to the AS + hs_token="abcde", + url="some_url", + ) + + # Register the application service + self._services.append(appservice) + + # Register a user on the homeserver + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login("local_user", "password") + + if as_should_receive_device_list_updates: + # Ensure that the resulting JSON uses the unstable prefix and contains the + # expected users + self.put_json.assert_called_once() + json_body = self.put_json.call_args[1]["json_body"] + + # Our application service should have received a device list update with + # "local_user" in the "changed" list + device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {}) + self.assertEqual([], device_list_dict["left"]) + self.assertEqual([self.local_user], device_list_dict["changed"]) + + else: + # No device list changes should have been sent out + self.put_json.assert_not_called() + + class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Argument indices for pulling out arguments from a `send_mock`. ARG_OTK_COUNTS = 4 diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 6249eb8c1..97de9a59e 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) ) self.assertEqual(txn.id, 1) @@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9646) self.assertEqual(txn.events, events) @@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9644) self.assertEqual(txn.events, events) @@ -320,7 +327,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9644) self.assertEqual(txn.events, events)