Send to-device messages to application services (#11215)

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
This commit is contained in:
Andrew Morgan 2022-02-01 14:13:38 +00:00 committed by GitHub
parent b7282fe7d1
commit 64ec45fc1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 856 additions and 162 deletions

View File

@ -0,0 +1 @@
Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.

View File

@ -351,11 +351,13 @@ class AppServiceTransaction:
id: int, id: int,
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
): ):
self.service = service self.service = service
self.id = id self.id = id
self.events = events self.events = events
self.ephemeral = ephemeral self.ephemeral = ephemeral
self.to_device_messages = to_device_messages
async def send(self, as_api: "ApplicationServiceApi") -> bool: async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface. """Sends this transaction using the provided AS API interface.
@ -369,6 +371,7 @@ class AppServiceTransaction:
service=self.service, service=self.service,
events=self.events, events=self.events,
ephemeral=self.ephemeral, ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages,
txn_id=self.id, txn_id=self.id,
) )

View File

@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient):
service: "ApplicationService", service: "ApplicationService",
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
txn_id: Optional[int] = None, txn_id: Optional[int] = None,
) -> bool: ) -> bool:
"""
Push data to an application service.
Args:
service: The application service to send to.
events: The persistent events to send.
ephemeral: The ephemeral events to send.
to_device_messages: The to-device messages to send.
txn_id: An unique ID to assign to this transaction. Application services should
deduplicate transactions received with identitical IDs.
Returns:
True if the task succeeded, False if it failed.
"""
if service.url is None: if service.url is None:
return True return True
@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it # Never send ephemeral events to appservices that do not support it
body: Dict[str, List[JsonDict]] = {"events": serialized_events}
if service.supports_ephemeral: if service.supports_ephemeral:
body = { body.update(
"events": serialized_events, {
# TODO: Update to stable prefixes once MSC2409 completes FCP merge.
"de.sorunome.msc2409.ephemeral": ephemeral, "de.sorunome.msc2409.ephemeral": ephemeral,
"de.sorunome.msc2409.to_device": to_device_messages,
} }
else: )
body = {"events": serialized_events}
try: try:
await self.put_json( await self.put_json(

View File

@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
List,
Optional,
Set,
)
from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
# Maximum number of ephemeral events to provide in an AS transaction. # Maximum number of ephemeral events to provide in an AS transaction.
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
# Maximum number of to-device messages to provide in an AS transaction.
MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
class ApplicationServiceScheduler: class ApplicationServiceScheduler:
"""Public facing API for this module. Does the required DI to tie the """Public facing API for this module. Does the required DI to tie the
@ -97,15 +109,40 @@ class ApplicationServiceScheduler:
for service in services: for service in services:
self.txn_ctrl.start_recoverer(service) self.txn_ctrl.start_recoverer(service)
def submit_event_for_as( def enqueue_for_appservice(
self, service: ApplicationService, event: EventBase self,
appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
) -> None: ) -> None:
self.queuer.enqueue_event(service, event) """
Enqueue some data to be sent off to an application service.
def submit_ephemeral_events_for_as( Args:
self, service: ApplicationService, events: List[JsonDict] appservice: The application service to create and send a transaction to.
) -> None: events: The persistent room events to send.
self.queuer.enqueue_ephemeral(service, events) ephemeral: The ephemeral events to send.
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages:
return
if events:
self.queuer.queued_events.setdefault(appservice.id, []).extend(events)
if ephemeral:
self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral)
if to_device_messages:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
class _ServiceQueuer: class _ServiceQueuer:
@ -121,13 +158,15 @@ class _ServiceQueuer:
self.queued_events: Dict[str, List[EventBase]] = {} self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]} # dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonDict]] = {} self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# the appservices which currently have a transaction in flight # the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set() self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock self.clock = clock
def _start_background_request(self, service: ApplicationService) -> None: def start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one # start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight: if service.id in self.requests_in_flight:
return return
@ -136,16 +175,6 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service "as-sender-%s" % (service.id,), self._send_request, service
) )
def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
def enqueue_ephemeral(
self, service: ApplicationService, events: List[JsonDict]
) -> None:
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
async def _send_request(self, service: ApplicationService) -> None: async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender # sanity-check: we shouldn't get here if this service already has a sender
# running. # running.
@ -162,11 +191,21 @@ class _ServiceQueuer:
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
if not events and not ephemeral: all_to_device_messages = self.queued_to_device_messages.get(
service.id, []
)
to_device_messages_to_send = all_to_device_messages[
:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send:
return return
try: try:
await self.txn_ctrl.send(service, events, ephemeral) await self.txn_ctrl.send(
service, events, ephemeral, to_device_messages_to_send
)
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
finally: finally:
@ -198,10 +237,24 @@ class _TransactionController:
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None, ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
) -> None: ) -> None:
"""
Create a transaction with the given data and send to the provided
application service.
Args:
service: The application service to send the transaction to.
events: The persistent events to include in the transaction.
ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction.
"""
try: try:
txn = await self.store.create_appservice_txn( txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral or [] service=service,
events=events,
ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [],
) )
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:

View File

@ -52,3 +52,10 @@ class ExperimentalConfig(Config):
self.msc3202_device_masquerading_enabled: bool = experimental.get( self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False "msc3202_device_masquerading", False
) )
# MSC2409 (this setting only relates to optionally sending to-device messages).
# Presence, typing and read receipt EDUs are already sent to application services that
# have opted in to receive them. If enabled, this adds to-device messages to that list.
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

View File

@ -55,6 +55,9 @@ class ApplicationServicesHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notify_appservices = hs.config.appservice.notify_appservices self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self.current_max = 0 self.current_max = 0
self.is_processing = False self.is_processing = False
@ -132,7 +135,9 @@ class ApplicationServicesHandler:
# Fork off pushes to these services # Fork off pushes to these services
for service in services: for service in services:
self.scheduler.submit_event_for_as(service, event) self.scheduler.enqueue_for_appservice(
service, events=[event]
)
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id) ts = await self.store.get_received_ts(event.event_id)
@ -199,8 +204,9 @@ class ApplicationServicesHandler:
Args: Args:
stream_key: The stream the event came from. stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other `stream_key` can be "typing_key", "receipt_key", "presence_key" or
value for `stream_key` will cause this function to return early. "to_device_key". Any other value for `stream_key` will cause this function
to return early.
Ephemeral events will only be pushed to appservices that have opted into Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration receiving them by setting `push_ephemeral` to true in their registration
@ -216,8 +222,15 @@ class ApplicationServicesHandler:
if not self.notify_appservices: if not self.notify_appservices:
return return
# Ignore any unsupported streams # Notify appservices of updates in ephemeral event streams.
if stream_key not in ("typing_key", "receipt_key", "presence_key"): # Only the following streams are currently supported.
# FIXME: We should use constants for these values.
if stream_key not in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
):
return return
# Assert that new_token is an integer (and not a RoomStreamToken). # Assert that new_token is an integer (and not a RoomStreamToken).
@ -233,6 +246,13 @@ class ApplicationServicesHandler:
# Additional context: https://github.com/matrix-org/synapse/pull/11137 # Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int) assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == "to_device_key"
and not self._msc2409_to_device_messages_enabled
):
return
# Check whether there are any appservices which have registered to receive # Check whether there are any appservices which have registered to receive
# ephemeral events. # ephemeral events.
# #
@ -266,7 +286,7 @@ class ApplicationServicesHandler:
with Measure(self.clock, "notify_interested_services_ephemeral"): with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services: for service in services:
if stream_key == "typing_key": if stream_key == "typing_key":
# Note that we don't persist the token (via set_type_stream_id_for_appservice) # Note that we don't persist the token (via set_appservice_stream_type_pos)
# for typing_key due to performance reasons and due to their highly # for typing_key due to performance reasons and due to their highly
# ephemeral nature. # ephemeral nature.
# #
@ -274,7 +294,7 @@ class ApplicationServicesHandler:
# and, if they apply to this application service, send it off. # and, if they apply to this application service, send it off.
events = await self._handle_typing(service, new_token) events = await self._handle_typing(service, new_token)
if events: if events:
self.scheduler.submit_ephemeral_events_for_as(service, events) self.scheduler.enqueue_for_appservice(service, ephemeral=events)
continue continue
# Since we read/update the stream position for this AS/stream # Since we read/update the stream position for this AS/stream
@ -285,26 +305,35 @@ class ApplicationServicesHandler:
): ):
if stream_key == "receipt_key": if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token) events = await self._handle_receipts(service, new_token)
if events: self.scheduler.enqueue_for_appservice(service, ephemeral=events)
self.scheduler.submit_ephemeral_events_for_as(
service, events
)
# Persist the latest handled stream token for this appservice # Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice( await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token service, "read_receipt", new_token
) )
elif stream_key == "presence_key": elif stream_key == "presence_key":
events = await self._handle_presence(service, users, new_token) events = await self._handle_presence(service, users, new_token)
if events: self.scheduler.enqueue_for_appservice(service, ephemeral=events)
self.scheduler.submit_ephemeral_events_for_as(
service, events # Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "presence", new_token
)
elif stream_key == "to_device_key":
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
service, new_token, users
)
self.scheduler.enqueue_for_appservice(
service, to_device_messages=to_device_messages
) )
# Persist the latest handled stream token for this appservice # Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice( await self.store.set_appservice_stream_type_pos(
service, "presence", new_token service, "to_device", new_token
) )
async def _handle_typing( async def _handle_typing(
@ -440,6 +469,79 @@ class ApplicationServicesHandler:
return events return events
async def _get_to_device_messages(
self,
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
) -> List[JsonDict]:
"""
Given an application service, determine which events it should receive
from those between the last-recorded to-device message stream token for this
appservice and the given stream token.
Args:
service: The application service to check for which events it should receive.
new_token: The latest to-device event stream token.
users: The users to be notified for the new to-device messages
(ie, the recipients of the messages).
Returns:
A list of JSON dictionaries containing data derived from the to-device events
that should be sent to the given application service.
"""
# Get the stream token that this application service has processed up until
from_key = await self.store.get_type_stream_id_for_appservice(
service, "to_device"
)
# Filter out users that this appservice is not interested in
users_appservice_is_interested_in: List[str] = []
for user in users:
# FIXME: We should do this farther up the call stack. We currently repeat
# this operation in _handle_presence.
if isinstance(user, UserID):
user = user.to_string()
if service.is_interested_in_user(user):
users_appservice_is_interested_in.append(user)
if not users_appservice_is_interested_in:
# Return early if the AS was not interested in any of these users
return []
# Retrieve the to-device messages for each user
recipient_device_to_messages = await self.store.get_messages_for_user_devices(
users_appservice_is_interested_in,
from_key,
new_token,
)
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
# to the event JSON so that the application service will know which user/device
# combination this messages was intended for.
#
# So we mangle this dict into a flat list of to-device messages with the relevant
# user ID and device ID embedded inside each message dict.
message_payload: List[JsonDict] = []
for (
user_id,
device_id,
), messages in recipient_device_to_messages.items():
for message_json in messages:
# Remove 'message_id' from the to-device message, as it's an internal ID
message_json.pop("message_id", None)
message_payload.append(
{
"to_user_id": user_id,
"to_device_id": device_id,
**message_json,
}
)
return message_payload
async def query_user_exists(self, user_id: str) -> bool: async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists. """Check if any application service knows this user_id exists.

View File

@ -1348,8 +1348,8 @@ class SyncHandler:
if sync_result_builder.since_token is not None: if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key) since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id != int(now_token.to_device_key): if device_id is not None and since_stream_id != int(now_token.to_device_key):
messages, stream_id = await self.store.get_new_messages_for_device( messages, stream_id = await self.store.get_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key user_id, device_id, since_stream_id, now_token.to_device_key
) )

View File

@ -461,7 +461,9 @@ class Notifier:
users, users,
) )
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception(
"Error notifying application services of ephemeral events"
)
def on_new_replication_data(self) -> None: def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened """Used to inform replication listeners that something has happened

View File

@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
) -> AppServiceTransaction: ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service """Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the with the given list of events. Ephemeral events are NOT persisted to the
@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The service who the transaction is for. service: The service who the transaction is for.
events: A list of persistent events to put in the transaction. events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
Returns: Returns:
A new transaction. A new transaction.
@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
(service.id, new_txn_id, event_ids), (service.id, new_txn_id, event_ids),
) )
return AppServiceTransaction( return AppServiceTransaction(
service=service, id=new_txn_id, events=events, ephemeral=ephemeral service=service,
id=new_txn_id,
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
) )
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return AppServiceTransaction( return AppServiceTransaction(
service=service, id=entry["txn_id"], events=events, ephemeral=[] service=service,
id=entry["txn_id"],
events=events,
ephemeral=[],
to_device_messages=[],
) )
def _get_last_txn(self, txn, service_id: Optional[str]) -> int: def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice( async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str self, service: ApplicationService, type: str
) -> int: ) -> int:
if type not in ("read_receipt", "presence"): if type not in ("read_receipt", "presence", "to_device"):
raise ValueError( raise ValueError(
"Expected type to be a valid application stream id type, got %s" "Expected type to be a valid application stream id type, got %s"
% (type,) % (type,)
@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
) )
async def set_type_stream_id_for_appservice( async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int] self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None: ) -> None:
if stream_type not in ("read_receipt", "presence"): if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError( raise ValueError(
"Expected type to be a valid application stream id type, got %s" "Expected type to be a valid application stream id type, got %s"
% (stream_type,) % (stream_type,)
) )
def set_type_stream_id_for_appservice_txn(txn): def set_appservice_stream_type_pos_txn(txn):
stream_id_type = "%s_stream_id" % stream_type stream_id_type = "%s_stream_id" % stream_type
txn.execute( txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?" "UPDATE application_services_state SET %s = ? WHERE as_id=?"
@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
) )

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from synapse.logging import issue9533_logger from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
@ -24,6 +24,7 @@ from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
make_in_list_sql_clause,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self): def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
async def get_new_messages_for_device( async def get_messages_for_user_devices(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
) -> Dict[Tuple[str, str], List[JsonDict]]:
"""
Retrieve to-device messages for a given set of users.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Args:
user_ids: The users to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
Returns:
A dictionary of (user id, device id) -> list of to-device messages.
"""
# We expect the stream ID returned by _get_device_messages to always
# be to_stream_id. So, no need to return it from this function.
(
user_id_device_id_to_messages,
last_processed_stream_id,
) = await self._get_device_messages(
user_ids=user_ids,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
)
assert (
last_processed_stream_id == to_stream_id
), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
return user_id_device_id_to_messages
async def get_messages_for_device(
self, self,
user_id: str, user_id: str,
device_id: Optional[str], device_id: str,
last_stream_id: int, from_stream_id: int,
current_stream_id: int, to_stream_id: int,
limit: int = 100, limit: int = 100,
) -> Tuple[List[dict], int]: ) -> Tuple[List[JsonDict], int]:
""" """
Retrieve to-device messages for a single user device.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Args: Args:
user_id: The recipient user_id. user_id: The ID of the user to retrieve messages for.
device_id: The recipient device_id. device_id: The ID of the device to retrieve to-device messages for.
last_stream_id: The last stream ID checked. from_stream_id: The lower boundary of stream id to filter with (exclusive).
current_stream_id: The current position of the to device to_stream_id: The upper boundary of stream id to filter with (inclusive).
message stream. limit: A limit on the number of to-device messages returned.
limit: The maximum number of messages to retrieve.
Returns: Returns:
A tuple containing: A tuple containing:
* A list of messages for the device. * A list of to-device messages within the given stream id range intended for
* The max stream token of these messages. There may be more to retrieve the given user / device combo.
if the given limit was reached. * The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
""" """
has_changed = self._device_inbox_stream_cache.has_entity_changed( (
user_id, last_stream_id user_id_device_id_to_messages,
) last_processed_stream_id,
if not has_changed: ) = await self._get_device_messages(
return [], current_stream_id user_ids=[user_id],
device_id=device_id,
def get_new_messages_for_device_txn(txn): from_stream_id=from_stream_id,
sql = ( to_stream_id=to_stream_id,
"SELECT stream_id, message_json FROM device_inbox" limit=limit,
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
) )
messages = [] if not user_id_device_id_to_messages:
stream_pos = current_stream_id # There were no messages!
return [], to_stream_id
# Extract the messages, no need to return the user and device ID again
to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
return to_device_messages, last_processed_stream_id
async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
device_id: Optional[str] = None,
limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Note that a stream ID can be shared by multiple copies of the same message with
different recipient devices. Stream IDs are only unique in the context of a single
user ID / device ID pair. Thus, applying a limit (of messages to return) when working
with a sliding window of stream IDs is only possible when querying messages of a
single user device.
Finally, note that device IDs are not unique across users.
Args:
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
device_id: A device ID to query to-device messages for. If not provided, to-device
messages from all device IDs for the given user IDs will be queried. May not be
provided if `user_ids` contains more than one entry.
limit: The maximum number of to-device messages to return. Can only be used when
passing a single user ID / device ID tuple.
Returns:
A tuple containing:
* A dict of (user_id, device_id) -> list of to-device messages
* The last-processed stream ID. If this is less than `to_stream_id`, then
there may be more messages to retrieve. If `limit` is not set, then this
is always equal to 'to_stream_id'.
"""
if not user_ids:
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id
# Prevent a query for one user's device also retrieving another user's device with
# the same device ID (device IDs are not unique across users).
if len(user_ids) > 1 and device_id is not None:
raise AssertionError(
"Programming error: 'device_id' cannot be supplied to "
"_get_device_messages when >1 user_id has been provided"
)
# A limit can only be applied when querying for a single user ID / device ID tuple.
# See the docstring of this function for more details.
if limit is not None and device_id is None:
raise AssertionError(
"Programming error: _get_device_messages was passed 'limit' "
"without a specific user_id/device_id"
)
user_ids_to_query: Set[str] = set()
device_ids_to_query: Set[str] = set()
# Note that a device ID could be an empty str
if device_id is not None:
# If a device ID was passed, use it to filter results.
# Otherwise, device IDs will be derived from the given collection of user IDs.
device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
# This user has new messages sent to them. Query messages for them
user_ids_to_query.add(user_id)
def get_device_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
# If a list of device IDs was not provided, retrieve all devices IDs
# for the given users. We explicitly do not query hidden devices, as
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"user_id": user_id, "hidden": False},
retcols=("device_id",),
)
device_ids_to_query.update(
{row["device_id"] for row in user_device_dicts}
)
if not device_ids_to_query:
# We've ended up with no devices to query.
return {}, to_stream_id
# We include both user IDs and device IDs in this query, as we have an index
# (device_inbox_user_stream_id) for them.
user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids_to_query
)
(
device_id_many_clause_sql,
device_id_many_clause_args,
) = make_in_list_sql_clause(
self.database_engine, "device_id", device_ids_to_query
)
sql = f"""
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
WHERE {user_id_many_clause_sql}
AND {device_id_many_clause_sql}
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
sql_args = (
*user_id_many_clause_args,
*device_id_many_clause_args,
from_stream_id,
to_stream_id,
)
# If a limit was provided, limit the data retrieved from the database
if limit is not None:
sql += "LIMIT ?"
sql_args += (limit,)
txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
for row in txn: for row in txn:
stream_pos = row[0] last_processed_stream_pos = row[0]
messages.append(db_to_json(row[1])) recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
# If the limit was not reached we know that there's no more data for this # Store the device details
# user/device pair up to current_stream_id. recipient_device_to_messages.setdefault(
if len(messages) < limit: (recipient_user_id, recipient_device_id), []
stream_pos = current_stream_id ).append(message_dict)
return messages, stream_pos if limit is not None and txn.rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return recipient_device_to_messages, last_processed_stream_pos
# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn "get_device_messages", get_device_messages_txn
) )
@trace @trace

View File

@ -0,0 +1,21 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what to_device stream id that this application
-- service has been caught up to.
-- NULL indicates that this appservice has never received any to_device messages. This
-- can be used, for example, to avoid sending a huge dump of messages at startup.
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;

View File

@ -11,23 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from synapse.appservice.scheduler import ( from synapse.appservice.scheduler import (
ApplicationServiceScheduler,
_Recoverer, _Recoverer,
_ServiceQueuer,
_TransactionController, _TransactionController,
) )
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from ..utils import MockClock from ..utils import MockClock
if TYPE_CHECKING:
from twisted.internet.testing import MemoryReactor
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved service=service,
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
) )
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed txn.complete.assert_called_once_with(self.store) # txn completed
@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved service=service,
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
) )
self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed self.assertEquals(0, txn.complete.call_count) # or completed
@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] service=service, events=events, ephemeral=[], to_device_messages=[]
) )
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer) self.callback.assert_called_once_with(self.recoverer)
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
def setUp(self): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer):
self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock() self.txn_ctrl.send = simple_async_mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
# Replace instantiated _TransactionController instances with our Mock
self.scheduler.txn_ctrl = self.txn_ctrl
self.scheduler.queuer.txn_ctrl = self.txn_ctrl
def test_send_single_event_no_queue(self): def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4) service = Mock(id=4)
event = Mock() event = Mock()
self.queuer.enqueue_event(service, event) self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], []) self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
self.txn_ctrl.send = Mock( self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4) service = Mock(id=4)
event = Mock(event_id="first") event = Mock(event_id="first")
event2 = Mock(event_id="second") event2 = Mock(event_id="second")
event3 = Mock(event_id="third") event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet. # Send an event and don't resolve it just yet.
self.queuer.enqueue_event(service, event) self.scheduler.enqueue_for_appservice(service, events=[event])
# Send more events: expect send() to NOT be called multiple times. # Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_event(service, event2) # (call enqueue_for_appservice multiple times deliberately)
self.queuer.enqueue_event(service, event3) self.scheduler.enqueue_for_appservice(service, events=[event2])
self.txn_ctrl.send.assert_called_with(service, [event], []) self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [])
self.assertEquals(1, self.txn_ctrl.send.call_count) self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent # Resolve the send event: expect the queued events to be sent
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self): def test_multiple_service_queues(self):
@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer] send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z): def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent # send events for different ASes and make sure they are sent
self.queuer.enqueue_event(srv1, srv_1_event) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.queuer.enqueue_event(srv1, srv_1_event2) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
self.queuer.enqueue_event(srv2, srv_2_event) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.queuer.enqueue_event(srv2, srv_2_event2) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
# make sure callbacks for a service only send queued events for THAT # make sure callbacks for a service only send queued events for THAT
# service # service
srv_2_defer.callback(srv2) srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count) self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self): def test_send_large_txns(self):
@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
srv_2_defer = defer.Deferred() srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer] send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z): def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)] event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list: for event in event_list:
self.queuer.enqueue_event(service, event) self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately. # Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], []) self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
srv_1_defer.callback(service) srv_1_defer.callback(service)
# Then send the next 100 events # Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], []) self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
srv_2_defer.callback(service) srv_2_defer.callback(service)
# Then the final 99 events # Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], []) self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count) self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self): def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event")] event_list = [Mock(name="event")]
self.queuer.enqueue_ephemeral(service, event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list) self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_multiple_ephemeral_no_queue(self): def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.queuer.enqueue_ephemeral(service, event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list) self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_single_ephemeral_with_queue(self): def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
self.txn_ctrl.send = Mock( self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4) service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet. # Send an event and don't resolve it just yet.
self.queuer.enqueue_ephemeral(service, event_list_1) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1)
# Send more events: expect send() to NOT be called multiple times. # Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_ephemeral(service, event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.queuer.enqueue_ephemeral(service, event_list_3) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1) self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
self.assertEquals(1, self.txn_ctrl.send.call_count) self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send # Resolve txn_ctrl.send
d.callback(service) d.callback(service)
# Expect the queued events to be sent # Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, []
)
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self): def test_send_large_txns_ephemeral(self):
d = defer.Deferred() d = defer.Deferred()
self.txn_ctrl.send = Mock( self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)] first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk event_list = first_chunk + second_chunk
self.queuer.enqueue_ephemeral(service, event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk) self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk) self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)

View File

@ -1,4 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.admin
import synapse.storage
from synapse.appservice import ApplicationService
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.stringutils import random_string
from tests.test_utils import make_awaitable from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
from tests.utils import MockClock from tests.utils import MockClock
from .. import unittest
class AppServiceHandlerTestCase(unittest.TestCase): class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler.""" """Tests the ApplicationServicesHandler."""
@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore.return_value = self.mock_store hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
None
)
hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
] ]
self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.submit_event_for_as.assert_called_once_with( self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, event interested_service, events=[event]
) )
def test_query_user_exists_unknown_user(self): def test_query_user_exists_unknown_user(self):
@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
""" """
interested_service = self._mkservice(is_interested=True) interested_service = self._mkservice(is_interested=True)
services = [interested_service] services = [interested_service]
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579 579
@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"] "receipt_key", 580, ["@fakerecipient:example.com"]
) )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, [event] interested_service, ephemeral=[event]
) )
self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with( self.mock_store.set_appservice_stream_type_pos.assert_called_once_with(
interested_service, interested_service,
"read_receipt", "read_receipt",
580, 580,
@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"] "receipt_key", 580, ["@fakerecipient:example.com"]
) )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() # This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[]
)
def _mkservice(self, is_interested, protocols=None): def _mkservice(self, is_interested, protocols=None):
service = Mock() service = Mock()
@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token" service.token = "mock_service_token"
service.url = "mock_service_url" service.url = "mock_service_url"
return service return service
class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"""
Tests that the ApplicationServicesHandler sends events to application
services correctly.
"""
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
sendtodevice.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor, clock, hs):
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
# A user on the homeserver.
self.local_user_device_id = "local_device"
self.local_user = self.register_user("local_user", "password")
self.local_user_token = self.login(
"local_user", "password", self.local_user_device_id
)
# A user on the homeserver which lies within an appservice's exclusive user namespace.
self.exclusive_as_user_device_id = "exclusive_as_device"
self.exclusive_as_user = self.register_user("exclusive_as_user", "password")
self.exclusive_as_user_token = self.login(
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_local_to_device(self):
"""
Test that when a user sends a to-device message to another user
that is an application service's user namespace, the
application service will receive it.
"""
interested_appservice = self._register_application_service(
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@exclusive_as_user:.+",
"exclusive": True,
}
],
},
)
# Have local_user send a to-device message to exclusive_as_user
message_content = {"some_key": "some really interesting value"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={
"messages": {
self.exclusive_as_user: {
self.exclusive_as_user_device_id: message_content
}
}
},
access_token=self.local_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
# Have exclusive_as_user send a to-device message to local_user
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
content={
"messages": {
self.local_user: {self.local_user_device_id: message_content}
}
},
access_token=self.exclusive_as_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
# Check if our application service - that is interested in exclusive_as_user - received
# the to-device message as part of an AS transaction.
# Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
#
# The uninterested application service should not have been notified at all.
self.send_mock.assert_called_once()
service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
self.assertEqual(service, interested_appservice)
self.assertEqual(to_device_messages[0]["type"], "m.room_key_request")
self.assertEqual(to_device_messages[0]["sender"], self.local_user)
# Additional fields 'to_user_id' and 'to_device_id' specifically for
# to-device messages via the AS API
self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
self.assertEqual(
to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id
)
self.assertEqual(to_device_messages[0]["content"], message_content)
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_bursts_of_to_device(self):
"""
Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions.
Also tests that uninterested application services do not receive messages.
"""
# Register two application services with exclusive interest in a user
interested_appservices = []
for _ in range(2):
appservice = self._register_application_service(
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@exclusive_as_user:.+",
"exclusive": True,
}
],
},
)
interested_appservices.append(appservice)
# ...and an application service which does not have any user interest.
self._register_application_service()
to_device_message_content = {
"some key": "some interesting value",
}
# We need to send a large burst of to-device messages. We also would like to
# include them all in the same application service transaction so that we can
# test large transactions.
#
# To do this, we can send a single to-device message to many user devices at
# once.
#
# We insert number_of_messages - 1 messages into the database directly. We'll then
# send a final to-device message to the real device, which will also kick off
# an AS transaction (as just inserting messages into the DB won't).
number_of_messages = 150
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
messages = {
self.exclusive_as_user: {
device_id: to_device_message_content for device_id in fake_device_ids
}
}
# Create a fake device per message. We can't send to-device messages to
# a device that doesn't exist.
self.get_success(
self.hs.get_datastore().db_pool.simple_insert_many(
desc="test_application_services_receive_burst_of_to_device",
table="devices",
keys=("user_id", "device_id"),
values=[
(
self.exclusive_as_user,
device_id,
)
for device_id in fake_device_ids
],
)
)
# Seed the device_inbox table with our fake messages
self.get_success(
self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
)
# Now have local_user send a final to-device message to exclusive_as_user. All unsent
# to-device messages should be sent to any application services
# interested in exclusive_as_user.
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
content={
"messages": {
self.exclusive_as_user: {
self.exclusive_as_user_device_id: to_device_message_content
}
}
},
access_token=self.local_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
self.send_mock.assert_called()
# Count the total number of to-device messages that were sent out per-service.
# Ensure that we only sent to-device messages to interested services, and that
# each interested service received the full count of to-device messages.
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
# Add to the count of messages for this application service
service_id_to_message_count.setdefault(service.id, 0)
service_id_to_message_count[service.id] += len(to_device_messages)
# Assert that each interested service received the full count of messages
for count in service_id_to_message_count.values():
self.assertEqual(count, number_of_messages)
def _register_application_service(
self,
namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
) -> ApplicationService:
"""
Register a new application service, with the given namespaces of interest.
Args:
namespaces: A dictionary containing any user, room or alias namespaces that
the application service is interested in.
Returns:
The registered application service.
"""
# Create an application service
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces=namespaces,
supports_ephemeral=True,
)
# Register the application service
self._services.append(appservice)
return appservice

View File

@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success( txn = self.get_success(
defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [])
)
) )
self.assertEquals(txn.id, 1) self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events)) self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(self.store.create_appservice_txn(service, events, [])) txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9646) self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643)) self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(self.store.create_appservice_txn(service, events, [])) txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(self.store.create_appservice_txn(service, events, [])) txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
ValueError, ValueError,
) )
def test_set_type_stream_id_for_appservice(self) -> None: def test_set_appservice_stream_type_pos(self) -> None:
read_receipt_value = 1024 read_receipt_value = 1024
self.get_success( self.get_success(
self.store.set_type_stream_id_for_appservice( self.store.set_appservice_stream_type_pos(
self.service, "read_receipt", read_receipt_value self.service, "read_receipt", read_receipt_value
) )
) )
@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
self.assertEqual(result, read_receipt_value) self.assertEqual(result, read_receipt_value)
self.get_success( self.get_success(
self.store.set_type_stream_id_for_appservice( self.store.set_appservice_stream_type_pos(
self.service, "presence", read_receipt_value self.service, "presence", read_receipt_value
) )
) )
@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
) )
self.assertEqual(result, read_receipt_value) self.assertEqual(result, read_receipt_value)
def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: def test_set_appservice_stream_type_pos_invalid_type(self) -> None:
self.get_failure( self.get_failure(
self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024),
ValueError, ValueError,
) )