mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Send to-device messages to application services (#11215)
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
This commit is contained in:
parent
b7282fe7d1
commit
64ec45fc1b
1
changelog.d/11215.feature
Normal file
1
changelog.d/11215.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.
|
@ -351,11 +351,13 @@ class AppServiceTransaction:
|
|||||||
id: int,
|
id: int,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: List[JsonDict],
|
ephemeral: List[JsonDict],
|
||||||
|
to_device_messages: List[JsonDict],
|
||||||
):
|
):
|
||||||
self.service = service
|
self.service = service
|
||||||
self.id = id
|
self.id = id
|
||||||
self.events = events
|
self.events = events
|
||||||
self.ephemeral = ephemeral
|
self.ephemeral = ephemeral
|
||||||
|
self.to_device_messages = to_device_messages
|
||||||
|
|
||||||
async def send(self, as_api: "ApplicationServiceApi") -> bool:
|
async def send(self, as_api: "ApplicationServiceApi") -> bool:
|
||||||
"""Sends this transaction using the provided AS API interface.
|
"""Sends this transaction using the provided AS API interface.
|
||||||
@ -369,6 +371,7 @@ class AppServiceTransaction:
|
|||||||
service=self.service,
|
service=self.service,
|
||||||
events=self.events,
|
events=self.events,
|
||||||
ephemeral=self.ephemeral,
|
ephemeral=self.ephemeral,
|
||||||
|
to_device_messages=self.to_device_messages,
|
||||||
txn_id=self.id,
|
txn_id=self.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
service: "ApplicationService",
|
service: "ApplicationService",
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: List[JsonDict],
|
ephemeral: List[JsonDict],
|
||||||
|
to_device_messages: List[JsonDict],
|
||||||
txn_id: Optional[int] = None,
|
txn_id: Optional[int] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Push data to an application service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: The application service to send to.
|
||||||
|
events: The persistent events to send.
|
||||||
|
ephemeral: The ephemeral events to send.
|
||||||
|
to_device_messages: The to-device messages to send.
|
||||||
|
txn_id: An unique ID to assign to this transaction. Application services should
|
||||||
|
deduplicate transactions received with identitical IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the task succeeded, False if it failed.
|
||||||
|
"""
|
||||||
if service.url is None:
|
if service.url is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
|
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
|
||||||
|
|
||||||
# Never send ephemeral events to appservices that do not support it
|
# Never send ephemeral events to appservices that do not support it
|
||||||
|
body: Dict[str, List[JsonDict]] = {"events": serialized_events}
|
||||||
if service.supports_ephemeral:
|
if service.supports_ephemeral:
|
||||||
body = {
|
body.update(
|
||||||
"events": serialized_events,
|
{
|
||||||
|
# TODO: Update to stable prefixes once MSC2409 completes FCP merge.
|
||||||
"de.sorunome.msc2409.ephemeral": ephemeral,
|
"de.sorunome.msc2409.ephemeral": ephemeral,
|
||||||
|
"de.sorunome.msc2409.to_device": to_device_messages,
|
||||||
}
|
}
|
||||||
else:
|
)
|
||||||
body = {"events": serialized_events}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.put_json(
|
await self.put_json(
|
||||||
|
@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required
|
|||||||
components.
|
components.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, ApplicationServiceState
|
from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||||
from synapse.appservice.api import ApplicationServiceApi
|
from synapse.appservice.api import ApplicationServiceApi
|
||||||
@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
|
|||||||
# Maximum number of ephemeral events to provide in an AS transaction.
|
# Maximum number of ephemeral events to provide in an AS transaction.
|
||||||
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
|
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
|
||||||
|
|
||||||
|
# Maximum number of to-device messages to provide in an AS transaction.
|
||||||
|
MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceScheduler:
|
class ApplicationServiceScheduler:
|
||||||
"""Public facing API for this module. Does the required DI to tie the
|
"""Public facing API for this module. Does the required DI to tie the
|
||||||
@ -97,15 +109,40 @@ class ApplicationServiceScheduler:
|
|||||||
for service in services:
|
for service in services:
|
||||||
self.txn_ctrl.start_recoverer(service)
|
self.txn_ctrl.start_recoverer(service)
|
||||||
|
|
||||||
def submit_event_for_as(
|
def enqueue_for_appservice(
|
||||||
self, service: ApplicationService, event: EventBase
|
self,
|
||||||
|
appservice: ApplicationService,
|
||||||
|
events: Optional[Collection[EventBase]] = None,
|
||||||
|
ephemeral: Optional[Collection[JsonDict]] = None,
|
||||||
|
to_device_messages: Optional[Collection[JsonDict]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.queuer.enqueue_event(service, event)
|
"""
|
||||||
|
Enqueue some data to be sent off to an application service.
|
||||||
|
|
||||||
def submit_ephemeral_events_for_as(
|
Args:
|
||||||
self, service: ApplicationService, events: List[JsonDict]
|
appservice: The application service to create and send a transaction to.
|
||||||
) -> None:
|
events: The persistent room events to send.
|
||||||
self.queuer.enqueue_ephemeral(service, events)
|
ephemeral: The ephemeral events to send.
|
||||||
|
to_device_messages: The to-device messages to send. These differ from normal
|
||||||
|
to-device messages sent to clients, as they have 'to_device_id' and
|
||||||
|
'to_user_id' fields.
|
||||||
|
"""
|
||||||
|
# We purposefully allow this method to run with empty events/ephemeral
|
||||||
|
# collections, so that callers do not need to check iterable size themselves.
|
||||||
|
if not events and not ephemeral and not to_device_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
if events:
|
||||||
|
self.queuer.queued_events.setdefault(appservice.id, []).extend(events)
|
||||||
|
if ephemeral:
|
||||||
|
self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral)
|
||||||
|
if to_device_messages:
|
||||||
|
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
|
||||||
|
to_device_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
# Kick off a new application service transaction
|
||||||
|
self.queuer.start_background_request(appservice)
|
||||||
|
|
||||||
|
|
||||||
class _ServiceQueuer:
|
class _ServiceQueuer:
|
||||||
@ -121,13 +158,15 @@ class _ServiceQueuer:
|
|||||||
self.queued_events: Dict[str, List[EventBase]] = {}
|
self.queued_events: Dict[str, List[EventBase]] = {}
|
||||||
# dict of {service_id: [events]}
|
# dict of {service_id: [events]}
|
||||||
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
|
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
|
||||||
|
# dict of {service_id: [to_device_message_json]}
|
||||||
|
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
|
||||||
|
|
||||||
# the appservices which currently have a transaction in flight
|
# the appservices which currently have a transaction in flight
|
||||||
self.requests_in_flight: Set[str] = set()
|
self.requests_in_flight: Set[str] = set()
|
||||||
self.txn_ctrl = txn_ctrl
|
self.txn_ctrl = txn_ctrl
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
|
|
||||||
def _start_background_request(self, service: ApplicationService) -> None:
|
def start_background_request(self, service: ApplicationService) -> None:
|
||||||
# start a sender for this appservice if we don't already have one
|
# start a sender for this appservice if we don't already have one
|
||||||
if service.id in self.requests_in_flight:
|
if service.id in self.requests_in_flight:
|
||||||
return
|
return
|
||||||
@ -136,16 +175,6 @@ class _ServiceQueuer:
|
|||||||
"as-sender-%s" % (service.id,), self._send_request, service
|
"as-sender-%s" % (service.id,), self._send_request, service
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
|
|
||||||
self.queued_events.setdefault(service.id, []).append(event)
|
|
||||||
self._start_background_request(service)
|
|
||||||
|
|
||||||
def enqueue_ephemeral(
|
|
||||||
self, service: ApplicationService, events: List[JsonDict]
|
|
||||||
) -> None:
|
|
||||||
self.queued_ephemeral.setdefault(service.id, []).extend(events)
|
|
||||||
self._start_background_request(service)
|
|
||||||
|
|
||||||
async def _send_request(self, service: ApplicationService) -> None:
|
async def _send_request(self, service: ApplicationService) -> None:
|
||||||
# sanity-check: we shouldn't get here if this service already has a sender
|
# sanity-check: we shouldn't get here if this service already has a sender
|
||||||
# running.
|
# running.
|
||||||
@ -162,11 +191,21 @@ class _ServiceQueuer:
|
|||||||
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
||||||
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
||||||
|
|
||||||
if not events and not ephemeral:
|
all_to_device_messages = self.queued_to_device_messages.get(
|
||||||
|
service.id, []
|
||||||
|
)
|
||||||
|
to_device_messages_to_send = all_to_device_messages[
|
||||||
|
:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION
|
||||||
|
]
|
||||||
|
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
|
||||||
|
|
||||||
|
if not events and not ephemeral and not to_device_messages_to_send:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.txn_ctrl.send(service, events, ephemeral)
|
await self.txn_ctrl.send(
|
||||||
|
service, events, ephemeral, to_device_messages_to_send
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("AS request failed")
|
logger.exception("AS request failed")
|
||||||
finally:
|
finally:
|
||||||
@ -198,10 +237,24 @@ class _TransactionController:
|
|||||||
service: ApplicationService,
|
service: ApplicationService,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: Optional[List[JsonDict]] = None,
|
ephemeral: Optional[List[JsonDict]] = None,
|
||||||
|
to_device_messages: Optional[List[JsonDict]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a transaction with the given data and send to the provided
|
||||||
|
application service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: The application service to send the transaction to.
|
||||||
|
events: The persistent events to include in the transaction.
|
||||||
|
ephemeral: The ephemeral events to include in the transaction.
|
||||||
|
to_device_messages: The to-device messages to include in the transaction.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
txn = await self.store.create_appservice_txn(
|
txn = await self.store.create_appservice_txn(
|
||||||
service=service, events=events, ephemeral=ephemeral or []
|
service=service,
|
||||||
|
events=events,
|
||||||
|
ephemeral=ephemeral or [],
|
||||||
|
to_device_messages=to_device_messages or [],
|
||||||
)
|
)
|
||||||
service_is_up = await self._is_service_up(service)
|
service_is_up = await self._is_service_up(service)
|
||||||
if service_is_up:
|
if service_is_up:
|
||||||
|
@ -52,3 +52,10 @@ class ExperimentalConfig(Config):
|
|||||||
self.msc3202_device_masquerading_enabled: bool = experimental.get(
|
self.msc3202_device_masquerading_enabled: bool = experimental.get(
|
||||||
"msc3202_device_masquerading", False
|
"msc3202_device_masquerading", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MSC2409 (this setting only relates to optionally sending to-device messages).
|
||||||
|
# Presence, typing and read receipt EDUs are already sent to application services that
|
||||||
|
# have opted in to receive them. If enabled, this adds to-device messages to that list.
|
||||||
|
self.msc2409_to_device_messages_enabled: bool = experimental.get(
|
||||||
|
"msc2409_to_device_messages_enabled", False
|
||||||
|
)
|
||||||
|
@ -55,6 +55,9 @@ class ApplicationServicesHandler:
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.notify_appservices = hs.config.appservice.notify_appservices
|
self.notify_appservices = hs.config.appservice.notify_appservices
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
|
self._msc2409_to_device_messages_enabled = (
|
||||||
|
hs.config.experimental.msc2409_to_device_messages_enabled
|
||||||
|
)
|
||||||
|
|
||||||
self.current_max = 0
|
self.current_max = 0
|
||||||
self.is_processing = False
|
self.is_processing = False
|
||||||
@ -132,7 +135,9 @@ class ApplicationServicesHandler:
|
|||||||
|
|
||||||
# Fork off pushes to these services
|
# Fork off pushes to these services
|
||||||
for service in services:
|
for service in services:
|
||||||
self.scheduler.submit_event_for_as(service, event)
|
self.scheduler.enqueue_for_appservice(
|
||||||
|
service, events=[event]
|
||||||
|
)
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
ts = await self.store.get_received_ts(event.event_id)
|
ts = await self.store.get_received_ts(event.event_id)
|
||||||
@ -199,8 +204,9 @@ class ApplicationServicesHandler:
|
|||||||
Args:
|
Args:
|
||||||
stream_key: The stream the event came from.
|
stream_key: The stream the event came from.
|
||||||
|
|
||||||
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
|
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
|
||||||
value for `stream_key` will cause this function to return early.
|
"to_device_key". Any other value for `stream_key` will cause this function
|
||||||
|
to return early.
|
||||||
|
|
||||||
Ephemeral events will only be pushed to appservices that have opted into
|
Ephemeral events will only be pushed to appservices that have opted into
|
||||||
receiving them by setting `push_ephemeral` to true in their registration
|
receiving them by setting `push_ephemeral` to true in their registration
|
||||||
@ -216,8 +222,15 @@ class ApplicationServicesHandler:
|
|||||||
if not self.notify_appservices:
|
if not self.notify_appservices:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ignore any unsupported streams
|
# Notify appservices of updates in ephemeral event streams.
|
||||||
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
|
# Only the following streams are currently supported.
|
||||||
|
# FIXME: We should use constants for these values.
|
||||||
|
if stream_key not in (
|
||||||
|
"typing_key",
|
||||||
|
"receipt_key",
|
||||||
|
"presence_key",
|
||||||
|
"to_device_key",
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Assert that new_token is an integer (and not a RoomStreamToken).
|
# Assert that new_token is an integer (and not a RoomStreamToken).
|
||||||
@ -233,6 +246,13 @@ class ApplicationServicesHandler:
|
|||||||
# Additional context: https://github.com/matrix-org/synapse/pull/11137
|
# Additional context: https://github.com/matrix-org/synapse/pull/11137
|
||||||
assert isinstance(new_token, int)
|
assert isinstance(new_token, int)
|
||||||
|
|
||||||
|
# Ignore to-device messages if the feature flag is not enabled
|
||||||
|
if (
|
||||||
|
stream_key == "to_device_key"
|
||||||
|
and not self._msc2409_to_device_messages_enabled
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
# Check whether there are any appservices which have registered to receive
|
# Check whether there are any appservices which have registered to receive
|
||||||
# ephemeral events.
|
# ephemeral events.
|
||||||
#
|
#
|
||||||
@ -266,7 +286,7 @@ class ApplicationServicesHandler:
|
|||||||
with Measure(self.clock, "notify_interested_services_ephemeral"):
|
with Measure(self.clock, "notify_interested_services_ephemeral"):
|
||||||
for service in services:
|
for service in services:
|
||||||
if stream_key == "typing_key":
|
if stream_key == "typing_key":
|
||||||
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
|
# Note that we don't persist the token (via set_appservice_stream_type_pos)
|
||||||
# for typing_key due to performance reasons and due to their highly
|
# for typing_key due to performance reasons and due to their highly
|
||||||
# ephemeral nature.
|
# ephemeral nature.
|
||||||
#
|
#
|
||||||
@ -274,7 +294,7 @@ class ApplicationServicesHandler:
|
|||||||
# and, if they apply to this application service, send it off.
|
# and, if they apply to this application service, send it off.
|
||||||
events = await self._handle_typing(service, new_token)
|
events = await self._handle_typing(service, new_token)
|
||||||
if events:
|
if events:
|
||||||
self.scheduler.submit_ephemeral_events_for_as(service, events)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Since we read/update the stream position for this AS/stream
|
# Since we read/update the stream position for this AS/stream
|
||||||
@ -285,26 +305,35 @@ class ApplicationServicesHandler:
|
|||||||
):
|
):
|
||||||
if stream_key == "receipt_key":
|
if stream_key == "receipt_key":
|
||||||
events = await self._handle_receipts(service, new_token)
|
events = await self._handle_receipts(service, new_token)
|
||||||
if events:
|
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
|
||||||
self.scheduler.submit_ephemeral_events_for_as(
|
|
||||||
service, events
|
|
||||||
)
|
|
||||||
|
|
||||||
# Persist the latest handled stream token for this appservice
|
# Persist the latest handled stream token for this appservice
|
||||||
await self.store.set_type_stream_id_for_appservice(
|
await self.store.set_appservice_stream_type_pos(
|
||||||
service, "read_receipt", new_token
|
service, "read_receipt", new_token
|
||||||
)
|
)
|
||||||
|
|
||||||
elif stream_key == "presence_key":
|
elif stream_key == "presence_key":
|
||||||
events = await self._handle_presence(service, users, new_token)
|
events = await self._handle_presence(service, users, new_token)
|
||||||
if events:
|
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
|
||||||
self.scheduler.submit_ephemeral_events_for_as(
|
|
||||||
service, events
|
# Persist the latest handled stream token for this appservice
|
||||||
|
await self.store.set_appservice_stream_type_pos(
|
||||||
|
service, "presence", new_token
|
||||||
|
)
|
||||||
|
|
||||||
|
elif stream_key == "to_device_key":
|
||||||
|
# Retrieve a list of to-device message events, as well as the
|
||||||
|
# maximum stream token of the messages we were able to retrieve.
|
||||||
|
to_device_messages = await self._get_to_device_messages(
|
||||||
|
service, new_token, users
|
||||||
|
)
|
||||||
|
self.scheduler.enqueue_for_appservice(
|
||||||
|
service, to_device_messages=to_device_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
# Persist the latest handled stream token for this appservice
|
# Persist the latest handled stream token for this appservice
|
||||||
await self.store.set_type_stream_id_for_appservice(
|
await self.store.set_appservice_stream_type_pos(
|
||||||
service, "presence", new_token
|
service, "to_device", new_token
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_typing(
|
async def _handle_typing(
|
||||||
@ -440,6 +469,79 @@ class ApplicationServicesHandler:
|
|||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
|
async def _get_to_device_messages(
|
||||||
|
self,
|
||||||
|
service: ApplicationService,
|
||||||
|
new_token: int,
|
||||||
|
users: Collection[Union[str, UserID]],
|
||||||
|
) -> List[JsonDict]:
|
||||||
|
"""
|
||||||
|
Given an application service, determine which events it should receive
|
||||||
|
from those between the last-recorded to-device message stream token for this
|
||||||
|
appservice and the given stream token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: The application service to check for which events it should receive.
|
||||||
|
new_token: The latest to-device event stream token.
|
||||||
|
users: The users to be notified for the new to-device messages
|
||||||
|
(ie, the recipients of the messages).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of JSON dictionaries containing data derived from the to-device events
|
||||||
|
that should be sent to the given application service.
|
||||||
|
"""
|
||||||
|
# Get the stream token that this application service has processed up until
|
||||||
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
|
service, "to_device"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter out users that this appservice is not interested in
|
||||||
|
users_appservice_is_interested_in: List[str] = []
|
||||||
|
for user in users:
|
||||||
|
# FIXME: We should do this farther up the call stack. We currently repeat
|
||||||
|
# this operation in _handle_presence.
|
||||||
|
if isinstance(user, UserID):
|
||||||
|
user = user.to_string()
|
||||||
|
|
||||||
|
if service.is_interested_in_user(user):
|
||||||
|
users_appservice_is_interested_in.append(user)
|
||||||
|
|
||||||
|
if not users_appservice_is_interested_in:
|
||||||
|
# Return early if the AS was not interested in any of these users
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Retrieve the to-device messages for each user
|
||||||
|
recipient_device_to_messages = await self.store.get_messages_for_user_devices(
|
||||||
|
users_appservice_is_interested_in,
|
||||||
|
from_key,
|
||||||
|
new_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
|
||||||
|
# to the event JSON so that the application service will know which user/device
|
||||||
|
# combination this messages was intended for.
|
||||||
|
#
|
||||||
|
# So we mangle this dict into a flat list of to-device messages with the relevant
|
||||||
|
# user ID and device ID embedded inside each message dict.
|
||||||
|
message_payload: List[JsonDict] = []
|
||||||
|
for (
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
), messages in recipient_device_to_messages.items():
|
||||||
|
for message_json in messages:
|
||||||
|
# Remove 'message_id' from the to-device message, as it's an internal ID
|
||||||
|
message_json.pop("message_id", None)
|
||||||
|
|
||||||
|
message_payload.append(
|
||||||
|
{
|
||||||
|
"to_user_id": user_id,
|
||||||
|
"to_device_id": device_id,
|
||||||
|
**message_json,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return message_payload
|
||||||
|
|
||||||
async def query_user_exists(self, user_id: str) -> bool:
|
async def query_user_exists(self, user_id: str) -> bool:
|
||||||
"""Check if any application service knows this user_id exists.
|
"""Check if any application service knows this user_id exists.
|
||||||
|
|
||||||
|
@ -1348,8 +1348,8 @@ class SyncHandler:
|
|||||||
if sync_result_builder.since_token is not None:
|
if sync_result_builder.since_token is not None:
|
||||||
since_stream_id = int(sync_result_builder.since_token.to_device_key)
|
since_stream_id = int(sync_result_builder.since_token.to_device_key)
|
||||||
|
|
||||||
if since_stream_id != int(now_token.to_device_key):
|
if device_id is not None and since_stream_id != int(now_token.to_device_key):
|
||||||
messages, stream_id = await self.store.get_new_messages_for_device(
|
messages, stream_id = await self.store.get_messages_for_device(
|
||||||
user_id, device_id, since_stream_id, now_token.to_device_key
|
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -461,7 +461,9 @@ class Notifier:
|
|||||||
users,
|
users,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error notifying application services of event")
|
logger.exception(
|
||||||
|
"Error notifying application services of ephemeral events"
|
||||||
|
)
|
||||||
|
|
||||||
def on_new_replication_data(self) -> None:
|
def on_new_replication_data(self) -> None:
|
||||||
"""Used to inform replication listeners that something has happened
|
"""Used to inform replication listeners that something has happened
|
||||||
|
@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
service: ApplicationService,
|
service: ApplicationService,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: List[JsonDict],
|
ephemeral: List[JsonDict],
|
||||||
|
to_device_messages: List[JsonDict],
|
||||||
) -> AppServiceTransaction:
|
) -> AppServiceTransaction:
|
||||||
"""Atomically creates a new transaction for this application service
|
"""Atomically creates a new transaction for this application service
|
||||||
with the given list of events. Ephemeral events are NOT persisted to the
|
with the given list of events. Ephemeral events are NOT persisted to the
|
||||||
@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
service: The service who the transaction is for.
|
service: The service who the transaction is for.
|
||||||
events: A list of persistent events to put in the transaction.
|
events: A list of persistent events to put in the transaction.
|
||||||
ephemeral: A list of ephemeral events to put in the transaction.
|
ephemeral: A list of ephemeral events to put in the transaction.
|
||||||
|
to_device_messages: A list of to-device messages to put in the transaction.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new transaction.
|
A new transaction.
|
||||||
@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
(service.id, new_txn_id, event_ids),
|
(service.id, new_txn_id, event_ids),
|
||||||
)
|
)
|
||||||
return AppServiceTransaction(
|
return AppServiceTransaction(
|
||||||
service=service, id=new_txn_id, events=events, ephemeral=ephemeral
|
service=service,
|
||||||
|
id=new_txn_id,
|
||||||
|
events=events,
|
||||||
|
ephemeral=ephemeral,
|
||||||
|
to_device_messages=to_device_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
events = await self.get_events_as_list(event_ids)
|
events = await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
return AppServiceTransaction(
|
return AppServiceTransaction(
|
||||||
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
service=service,
|
||||||
|
id=entry["txn_id"],
|
||||||
|
events=events,
|
||||||
|
ephemeral=[],
|
||||||
|
to_device_messages=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||||
@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
async def get_type_stream_id_for_appservice(
|
async def get_type_stream_id_for_appservice(
|
||||||
self, service: ApplicationService, type: str
|
self, service: ApplicationService, type: str
|
||||||
) -> int:
|
) -> int:
|
||||||
if type not in ("read_receipt", "presence"):
|
if type not in ("read_receipt", "presence", "to_device"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected type to be a valid application stream id type, got %s"
|
"Expected type to be a valid application stream id type, got %s"
|
||||||
% (type,)
|
% (type,)
|
||||||
@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
|
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_type_stream_id_for_appservice(
|
async def set_appservice_stream_type_pos(
|
||||||
self, service: ApplicationService, stream_type: str, pos: Optional[int]
|
self, service: ApplicationService, stream_type: str, pos: Optional[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
if stream_type not in ("read_receipt", "presence"):
|
if stream_type not in ("read_receipt", "presence", "to_device"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected type to be a valid application stream id type, got %s"
|
"Expected type to be a valid application stream id type, got %s"
|
||||||
% (stream_type,)
|
% (stream_type,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_type_stream_id_for_appservice_txn(txn):
|
def set_appservice_stream_type_pos_txn(txn):
|
||||||
stream_id_type = "%s_stream_id" % stream_type
|
stream_id_type = "%s_stream_id" % stream_type
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
|
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
|
||||||
@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
|
"set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
|
||||||
|
|
||||||
from synapse.logging import issue9533_logger
|
from synapse.logging import issue9533_logger
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
@ -24,6 +24,7 @@ from synapse.storage.database import (
|
|||||||
DatabasePool,
|
DatabasePool,
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
def get_to_device_stream_token(self):
|
def get_to_device_stream_token(self):
|
||||||
return self._device_inbox_id_gen.get_current_token()
|
return self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
||||||
async def get_new_messages_for_device(
|
async def get_messages_for_user_devices(
|
||||||
|
self,
|
||||||
|
user_ids: Collection[str],
|
||||||
|
from_stream_id: int,
|
||||||
|
to_stream_id: int,
|
||||||
|
) -> Dict[Tuple[str, str], List[JsonDict]]:
|
||||||
|
"""
|
||||||
|
Retrieve to-device messages for a given set of users.
|
||||||
|
|
||||||
|
Only to-device messages with stream ids between the given boundaries
|
||||||
|
(from < X <= to) are returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids: The users to retrieve to-device messages for.
|
||||||
|
from_stream_id: The lower boundary of stream id to filter with (exclusive).
|
||||||
|
to_stream_id: The upper boundary of stream id to filter with (inclusive).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of (user id, device id) -> list of to-device messages.
|
||||||
|
"""
|
||||||
|
# We expect the stream ID returned by _get_device_messages to always
|
||||||
|
# be to_stream_id. So, no need to return it from this function.
|
||||||
|
(
|
||||||
|
user_id_device_id_to_messages,
|
||||||
|
last_processed_stream_id,
|
||||||
|
) = await self._get_device_messages(
|
||||||
|
user_ids=user_ids,
|
||||||
|
from_stream_id=from_stream_id,
|
||||||
|
to_stream_id=to_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
last_processed_stream_id == to_stream_id
|
||||||
|
), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
|
||||||
|
|
||||||
|
return user_id_device_id_to_messages
|
||||||
|
|
||||||
|
async def get_messages_for_device(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_id: Optional[str],
|
device_id: str,
|
||||||
last_stream_id: int,
|
from_stream_id: int,
|
||||||
current_stream_id: int,
|
to_stream_id: int,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> Tuple[List[dict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
"""
|
"""
|
||||||
|
Retrieve to-device messages for a single user device.
|
||||||
|
|
||||||
|
Only to-device messages with stream ids between the given boundaries
|
||||||
|
(from < X <= to) are returned.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The recipient user_id.
|
user_id: The ID of the user to retrieve messages for.
|
||||||
device_id: The recipient device_id.
|
device_id: The ID of the device to retrieve to-device messages for.
|
||||||
last_stream_id: The last stream ID checked.
|
from_stream_id: The lower boundary of stream id to filter with (exclusive).
|
||||||
current_stream_id: The current position of the to device
|
to_stream_id: The upper boundary of stream id to filter with (inclusive).
|
||||||
message stream.
|
limit: A limit on the number of to-device messages returned.
|
||||||
limit: The maximum number of messages to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing:
|
A tuple containing:
|
||||||
* A list of messages for the device.
|
* A list of to-device messages within the given stream id range intended for
|
||||||
* The max stream token of these messages. There may be more to retrieve
|
the given user / device combo.
|
||||||
if the given limit was reached.
|
* The last-processed stream ID. Subsequent calls of this function with the
|
||||||
|
same device should pass this value as 'from_stream_id'.
|
||||||
"""
|
"""
|
||||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
(
|
||||||
user_id, last_stream_id
|
user_id_device_id_to_messages,
|
||||||
)
|
last_processed_stream_id,
|
||||||
if not has_changed:
|
) = await self._get_device_messages(
|
||||||
return [], current_stream_id
|
user_ids=[user_id],
|
||||||
|
device_id=device_id,
|
||||||
def get_new_messages_for_device_txn(txn):
|
from_stream_id=from_stream_id,
|
||||||
sql = (
|
to_stream_id=to_stream_id,
|
||||||
"SELECT stream_id, message_json FROM device_inbox"
|
limit=limit,
|
||||||
" WHERE user_id = ? AND device_id = ?"
|
|
||||||
" AND ? < stream_id AND stream_id <= ?"
|
|
||||||
" ORDER BY stream_id ASC"
|
|
||||||
" LIMIT ?"
|
|
||||||
)
|
|
||||||
txn.execute(
|
|
||||||
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = []
|
if not user_id_device_id_to_messages:
|
||||||
stream_pos = current_stream_id
|
# There were no messages!
|
||||||
|
return [], to_stream_id
|
||||||
|
|
||||||
|
# Extract the messages, no need to return the user and device ID again
|
||||||
|
to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
|
||||||
|
|
||||||
|
return to_device_messages, last_processed_stream_id
|
||||||
|
|
||||||
|
async def _get_device_messages(
|
||||||
|
self,
|
||||||
|
user_ids: Collection[str],
|
||||||
|
from_stream_id: int,
|
||||||
|
to_stream_id: int,
|
||||||
|
device_id: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
|
||||||
|
"""
|
||||||
|
Retrieve pending to-device messages for a collection of user devices.
|
||||||
|
|
||||||
|
Only to-device messages with stream ids between the given boundaries
|
||||||
|
(from < X <= to) are returned.
|
||||||
|
|
||||||
|
Note that a stream ID can be shared by multiple copies of the same message with
|
||||||
|
different recipient devices. Stream IDs are only unique in the context of a single
|
||||||
|
user ID / device ID pair. Thus, applying a limit (of messages to return) when working
|
||||||
|
with a sliding window of stream IDs is only possible when querying messages of a
|
||||||
|
single user device.
|
||||||
|
|
||||||
|
Finally, note that device IDs are not unique across users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids: The user IDs to filter device messages by.
|
||||||
|
from_stream_id: The lower boundary of stream id to filter with (exclusive).
|
||||||
|
to_stream_id: The upper boundary of stream id to filter with (inclusive).
|
||||||
|
device_id: A device ID to query to-device messages for. If not provided, to-device
|
||||||
|
messages from all device IDs for the given user IDs will be queried. May not be
|
||||||
|
provided if `user_ids` contains more than one entry.
|
||||||
|
limit: The maximum number of to-device messages to return. Can only be used when
|
||||||
|
passing a single user ID / device ID tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing:
|
||||||
|
* A dict of (user_id, device_id) -> list of to-device messages
|
||||||
|
* The last-processed stream ID. If this is less than `to_stream_id`, then
|
||||||
|
there may be more messages to retrieve. If `limit` is not set, then this
|
||||||
|
is always equal to 'to_stream_id'.
|
||||||
|
"""
|
||||||
|
if not user_ids:
|
||||||
|
logger.warning("No users provided upon querying for device IDs")
|
||||||
|
return {}, to_stream_id
|
||||||
|
|
||||||
|
# Prevent a query for one user's device also retrieving another user's device with
|
||||||
|
# the same device ID (device IDs are not unique across users).
|
||||||
|
if len(user_ids) > 1 and device_id is not None:
|
||||||
|
raise AssertionError(
|
||||||
|
"Programming error: 'device_id' cannot be supplied to "
|
||||||
|
"_get_device_messages when >1 user_id has been provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# A limit can only be applied when querying for a single user ID / device ID tuple.
|
||||||
|
# See the docstring of this function for more details.
|
||||||
|
if limit is not None and device_id is None:
|
||||||
|
raise AssertionError(
|
||||||
|
"Programming error: _get_device_messages was passed 'limit' "
|
||||||
|
"without a specific user_id/device_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids_to_query: Set[str] = set()
|
||||||
|
device_ids_to_query: Set[str] = set()
|
||||||
|
|
||||||
|
# Note that a device ID could be an empty str
|
||||||
|
if device_id is not None:
|
||||||
|
# If a device ID was passed, use it to filter results.
|
||||||
|
# Otherwise, device IDs will be derived from the given collection of user IDs.
|
||||||
|
device_ids_to_query.add(device_id)
|
||||||
|
|
||||||
|
# Determine which users have devices with pending messages
|
||||||
|
for user_id in user_ids:
|
||||||
|
if self._device_inbox_stream_cache.has_entity_changed(
|
||||||
|
user_id, from_stream_id
|
||||||
|
):
|
||||||
|
# This user has new messages sent to them. Query messages for them
|
||||||
|
user_ids_to_query.add(user_id)
|
||||||
|
|
||||||
|
def get_device_messages_txn(txn: LoggingTransaction):
|
||||||
|
# Build a query to select messages from any of the given devices that
|
||||||
|
# are between the given stream id bounds.
|
||||||
|
|
||||||
|
# If a list of device IDs was not provided, retrieve all devices IDs
|
||||||
|
# for the given users. We explicitly do not query hidden devices, as
|
||||||
|
# hidden devices should not receive to-device messages.
|
||||||
|
# Note that this is more efficient than just dropping `device_id` from the query,
|
||||||
|
# since device_inbox has an index on `(user_id, device_id, stream_id)`
|
||||||
|
if not device_ids_to_query:
|
||||||
|
user_device_dicts = self.db_pool.simple_select_many_txn(
|
||||||
|
txn,
|
||||||
|
table="devices",
|
||||||
|
column="user_id",
|
||||||
|
iterable=user_ids_to_query,
|
||||||
|
keyvalues={"user_id": user_id, "hidden": False},
|
||||||
|
retcols=("device_id",),
|
||||||
|
)
|
||||||
|
|
||||||
|
device_ids_to_query.update(
|
||||||
|
{row["device_id"] for row in user_device_dicts}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not device_ids_to_query:
|
||||||
|
# We've ended up with no devices to query.
|
||||||
|
return {}, to_stream_id
|
||||||
|
|
||||||
|
# We include both user IDs and device IDs in this query, as we have an index
|
||||||
|
# (device_inbox_user_stream_id) for them.
|
||||||
|
user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
|
||||||
|
self.database_engine, "user_id", user_ids_to_query
|
||||||
|
)
|
||||||
|
(
|
||||||
|
device_id_many_clause_sql,
|
||||||
|
device_id_many_clause_args,
|
||||||
|
) = make_in_list_sql_clause(
|
||||||
|
self.database_engine, "device_id", device_ids_to_query
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
|
||||||
|
WHERE {user_id_many_clause_sql}
|
||||||
|
AND {device_id_many_clause_sql}
|
||||||
|
AND ? < stream_id AND stream_id <= ?
|
||||||
|
ORDER BY stream_id ASC
|
||||||
|
"""
|
||||||
|
sql_args = (
|
||||||
|
*user_id_many_clause_args,
|
||||||
|
*device_id_many_clause_args,
|
||||||
|
from_stream_id,
|
||||||
|
to_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If a limit was provided, limit the data retrieved from the database
|
||||||
|
if limit is not None:
|
||||||
|
sql += "LIMIT ?"
|
||||||
|
sql_args += (limit,)
|
||||||
|
|
||||||
|
txn.execute(sql, sql_args)
|
||||||
|
|
||||||
|
# Create and fill a dictionary of (user ID, device ID) -> list of messages
|
||||||
|
# intended for each device.
|
||||||
|
last_processed_stream_pos = to_stream_id
|
||||||
|
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
|
||||||
for row in txn:
|
for row in txn:
|
||||||
stream_pos = row[0]
|
last_processed_stream_pos = row[0]
|
||||||
messages.append(db_to_json(row[1]))
|
recipient_user_id = row[1]
|
||||||
|
recipient_device_id = row[2]
|
||||||
|
message_dict = db_to_json(row[3])
|
||||||
|
|
||||||
# If the limit was not reached we know that there's no more data for this
|
# Store the device details
|
||||||
# user/device pair up to current_stream_id.
|
recipient_device_to_messages.setdefault(
|
||||||
if len(messages) < limit:
|
(recipient_user_id, recipient_device_id), []
|
||||||
stream_pos = current_stream_id
|
).append(message_dict)
|
||||||
|
|
||||||
return messages, stream_pos
|
if limit is not None and txn.rowcount == limit:
|
||||||
|
# We ended up bumping up against the message limit. There may be more messages
|
||||||
|
# to retrieve. Return what we have, as well as the last stream position that
|
||||||
|
# was processed.
|
||||||
|
#
|
||||||
|
# The caller is expected to set this as the lower (exclusive) bound
|
||||||
|
# for the next query of this device.
|
||||||
|
return recipient_device_to_messages, last_processed_stream_pos
|
||||||
|
|
||||||
|
# The limit was not reached, thus we know that recipient_device_to_messages
|
||||||
|
# contains all to-device messages for the given device and stream id range.
|
||||||
|
#
|
||||||
|
# We return to_stream_id, which the caller should then provide as the lower
|
||||||
|
# (exclusive) bound on the next query of this device.
|
||||||
|
return recipient_device_to_messages, to_stream_id
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_new_messages_for_device", get_new_messages_for_device_txn
|
"get_device_messages", get_device_messages_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
|
@ -0,0 +1,21 @@
|
|||||||
|
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Add a column to track what to_device stream id that this application
|
||||||
|
-- service has been caught up to.
|
||||||
|
|
||||||
|
-- NULL indicates that this appservice has never received any to_device messages. This
|
||||||
|
-- can be used, for example, to avoid sending a huge dump of messages at startup.
|
||||||
|
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;
|
@ -11,23 +11,29 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.appservice import ApplicationServiceState
|
from synapse.appservice import ApplicationServiceState
|
||||||
from synapse.appservice.scheduler import (
|
from synapse.appservice.scheduler import (
|
||||||
|
ApplicationServiceScheduler,
|
||||||
_Recoverer,
|
_Recoverer,
|
||||||
_ServiceQueuer,
|
|
||||||
_TransactionController,
|
_TransactionController,
|
||||||
)
|
)
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
|
|
||||||
from ..utils import MockClock
|
from ..utils import MockClock
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from twisted.internet.testing import MemoryReactor
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
|
||||||
self.store.create_appservice_txn.assert_called_once_with(
|
self.store.create_appservice_txn.assert_called_once_with(
|
||||||
service=service, events=events, ephemeral=[] # txn made and saved
|
service=service,
|
||||||
|
events=events,
|
||||||
|
ephemeral=[],
|
||||||
|
to_device_messages=[], # txn made and saved
|
||||||
)
|
)
|
||||||
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
|
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
|
||||||
txn.complete.assert_called_once_with(self.store) # txn completed
|
txn.complete.assert_called_once_with(self.store) # txn completed
|
||||||
@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
|
||||||
self.store.create_appservice_txn.assert_called_once_with(
|
self.store.create_appservice_txn.assert_called_once_with(
|
||||||
service=service, events=events, ephemeral=[] # txn made and saved
|
service=service,
|
||||||
|
events=events,
|
||||||
|
ephemeral=[],
|
||||||
|
to_device_messages=[], # txn made and saved
|
||||||
)
|
)
|
||||||
self.assertEquals(0, txn.send.call_count) # txn not sent though
|
self.assertEquals(0, txn.send.call_count) # txn not sent though
|
||||||
self.assertEquals(0, txn.complete.call_count) # or completed
|
self.assertEquals(0, txn.complete.call_count) # or completed
|
||||||
@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
|
||||||
self.store.create_appservice_txn.assert_called_once_with(
|
self.store.create_appservice_txn.assert_called_once_with(
|
||||||
service=service, events=events, ephemeral=[]
|
service=service, events=events, ephemeral=[], to_device_messages=[]
|
||||||
)
|
)
|
||||||
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
|
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
|
||||||
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
|
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
|
||||||
@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||||||
self.callback.assert_called_once_with(self.recoverer)
|
self.callback.assert_called_once_with(self.recoverer)
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer):
|
||||||
|
self.scheduler = ApplicationServiceScheduler(hs)
|
||||||
self.txn_ctrl = Mock()
|
self.txn_ctrl = Mock()
|
||||||
self.txn_ctrl.send = simple_async_mock()
|
self.txn_ctrl.send = simple_async_mock()
|
||||||
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
|
||||||
|
# Replace instantiated _TransactionController instances with our Mock
|
||||||
|
self.scheduler.txn_ctrl = self.txn_ctrl
|
||||||
|
self.scheduler.queuer.txn_ctrl = self.txn_ctrl
|
||||||
|
|
||||||
def test_send_single_event_no_queue(self):
|
def test_send_single_event_no_queue(self):
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
service = Mock(id=4)
|
service = Mock(id=4)
|
||||||
event = Mock()
|
event = Mock()
|
||||||
self.queuer.enqueue_event(service, event)
|
self.scheduler.enqueue_for_appservice(service, events=[event])
|
||||||
self.txn_ctrl.send.assert_called_once_with(service, [event], [])
|
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
|
||||||
|
|
||||||
def test_send_single_event_with_queue(self):
|
def test_send_single_event_with_queue(self):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self.txn_ctrl.send = Mock(
|
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
|
||||||
side_effect=lambda x, y, z: make_deferred_yieldable(d)
|
|
||||||
)
|
|
||||||
service = Mock(id=4)
|
service = Mock(id=4)
|
||||||
event = Mock(event_id="first")
|
event = Mock(event_id="first")
|
||||||
event2 = Mock(event_id="second")
|
event2 = Mock(event_id="second")
|
||||||
event3 = Mock(event_id="third")
|
event3 = Mock(event_id="third")
|
||||||
# Send an event and don't resolve it just yet.
|
# Send an event and don't resolve it just yet.
|
||||||
self.queuer.enqueue_event(service, event)
|
self.scheduler.enqueue_for_appservice(service, events=[event])
|
||||||
# Send more events: expect send() to NOT be called multiple times.
|
# Send more events: expect send() to NOT be called multiple times.
|
||||||
self.queuer.enqueue_event(service, event2)
|
# (call enqueue_for_appservice multiple times deliberately)
|
||||||
self.queuer.enqueue_event(service, event3)
|
self.scheduler.enqueue_for_appservice(service, events=[event2])
|
||||||
self.txn_ctrl.send.assert_called_with(service, [event], [])
|
self.scheduler.enqueue_for_appservice(service, events=[event3])
|
||||||
|
self.txn_ctrl.send.assert_called_with(service, [event], [], [])
|
||||||
self.assertEquals(1, self.txn_ctrl.send.call_count)
|
self.assertEquals(1, self.txn_ctrl.send.call_count)
|
||||||
# Resolve the send event: expect the queued events to be sent
|
# Resolve the send event: expect the queued events to be sent
|
||||||
d.callback(service)
|
d.callback(service)
|
||||||
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
|
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
|
||||||
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
def test_multiple_service_queues(self):
|
def test_multiple_service_queues(self):
|
||||||
@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
send_return_list = [srv_1_defer, srv_2_defer]
|
send_return_list = [srv_1_defer, srv_2_defer]
|
||||||
|
|
||||||
def do_send(x, y, z):
|
def do_send(*args, **kwargs):
|
||||||
return make_deferred_yieldable(send_return_list.pop(0))
|
return make_deferred_yieldable(send_return_list.pop(0))
|
||||||
|
|
||||||
self.txn_ctrl.send = Mock(side_effect=do_send)
|
self.txn_ctrl.send = Mock(side_effect=do_send)
|
||||||
|
|
||||||
# send events for different ASes and make sure they are sent
|
# send events for different ASes and make sure they are sent
|
||||||
self.queuer.enqueue_event(srv1, srv_1_event)
|
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
|
||||||
self.queuer.enqueue_event(srv1, srv_1_event2)
|
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
|
||||||
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
|
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
|
||||||
self.queuer.enqueue_event(srv2, srv_2_event)
|
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
|
||||||
self.queuer.enqueue_event(srv2, srv_2_event2)
|
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
|
||||||
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
|
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
|
||||||
|
|
||||||
# make sure callbacks for a service only send queued events for THAT
|
# make sure callbacks for a service only send queued events for THAT
|
||||||
# service
|
# service
|
||||||
srv_2_defer.callback(srv2)
|
srv_2_defer.callback(srv2)
|
||||||
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
|
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
|
||||||
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
def test_send_large_txns(self):
|
def test_send_large_txns(self):
|
||||||
@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
srv_2_defer = defer.Deferred()
|
srv_2_defer = defer.Deferred()
|
||||||
send_return_list = [srv_1_defer, srv_2_defer]
|
send_return_list = [srv_1_defer, srv_2_defer]
|
||||||
|
|
||||||
def do_send(x, y, z):
|
def do_send(*args, **kwargs):
|
||||||
return make_deferred_yieldable(send_return_list.pop(0))
|
return make_deferred_yieldable(send_return_list.pop(0))
|
||||||
|
|
||||||
self.txn_ctrl.send = Mock(side_effect=do_send)
|
self.txn_ctrl.send = Mock(side_effect=do_send)
|
||||||
@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||||||
service = Mock(id=4, name="service")
|
service = Mock(id=4, name="service")
|
||||||
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
|
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
|
||||||
for event in event_list:
|
for event in event_list:
|
||||||
self.queuer.enqueue_event(service, event)
|
self.scheduler.enqueue_for_appservice(service, [event], [])
|
||||||
|
|
||||||
# Expect the first event to be sent immediately.
|
# Expect the first event to be sent immediately.
|
||||||
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
|
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
|
||||||
srv_1_defer.callback(service)
|
srv_1_defer.callback(service)
|
||||||
# Then send the next 100 events
|
# Then send the next 100 events
|
||||||
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
|
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
|
||||||
srv_2_defer.callback(service)
|
srv_2_defer.callback(service)
|
||||||
# Then the final 99 events
|
# Then the final 99 events
|
||||||
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
|
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
|
||||||
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
def test_send_single_ephemeral_no_queue(self):
|
def test_send_single_ephemeral_no_queue(self):
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
service = Mock(id=4, name="service")
|
service = Mock(id=4, name="service")
|
||||||
event_list = [Mock(name="event")]
|
event_list = [Mock(name="event")]
|
||||||
self.queuer.enqueue_ephemeral(service, event_list)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||||
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
|
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
|
||||||
|
|
||||||
def test_send_multiple_ephemeral_no_queue(self):
|
def test_send_multiple_ephemeral_no_queue(self):
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
service = Mock(id=4, name="service")
|
service = Mock(id=4, name="service")
|
||||||
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
|
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
|
||||||
self.queuer.enqueue_ephemeral(service, event_list)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||||
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
|
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
|
||||||
|
|
||||||
def test_send_single_ephemeral_with_queue(self):
|
def test_send_single_ephemeral_with_queue(self):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self.txn_ctrl.send = Mock(
|
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
|
||||||
side_effect=lambda x, y, z: make_deferred_yieldable(d)
|
|
||||||
)
|
|
||||||
service = Mock(id=4)
|
service = Mock(id=4)
|
||||||
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
|
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
|
||||||
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
|
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
|
||||||
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
|
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
|
||||||
|
|
||||||
# Send an event and don't resolve it just yet.
|
# Send an event and don't resolve it just yet.
|
||||||
self.queuer.enqueue_ephemeral(service, event_list_1)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1)
|
||||||
# Send more events: expect send() to NOT be called multiple times.
|
# Send more events: expect send() to NOT be called multiple times.
|
||||||
self.queuer.enqueue_ephemeral(service, event_list_2)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
|
||||||
self.queuer.enqueue_ephemeral(service, event_list_3)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
|
||||||
self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
|
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
|
||||||
self.assertEquals(1, self.txn_ctrl.send.call_count)
|
self.assertEquals(1, self.txn_ctrl.send.call_count)
|
||||||
# Resolve txn_ctrl.send
|
# Resolve txn_ctrl.send
|
||||||
d.callback(service)
|
d.callback(service)
|
||||||
# Expect the queued events to be sent
|
# Expect the queued events to be sent
|
||||||
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
|
self.txn_ctrl.send.assert_called_with(
|
||||||
|
service, [], event_list_2 + event_list_3, []
|
||||||
|
)
|
||||||
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
def test_send_large_txns_ephemeral(self):
|
def test_send_large_txns_ephemeral(self):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self.txn_ctrl.send = Mock(
|
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
|
||||||
side_effect=lambda x, y, z: make_deferred_yieldable(d)
|
|
||||||
)
|
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
service = Mock(id=4, name="service")
|
service = Mock(id=4, name="service")
|
||||||
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
|
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
|
||||||
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
|
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
|
||||||
event_list = first_chunk + second_chunk
|
event_list = first_chunk + second_chunk
|
||||||
self.queuer.enqueue_ephemeral(service, event_list)
|
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||||
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
|
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
|
||||||
d.callback(service)
|
d.callback(service)
|
||||||
self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
|
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
|
||||||
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,18 +12,23 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.rest.admin
|
||||||
|
import synapse.storage
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
|
from synapse.rest.client import login, receipts, room, sendtodevice
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests.test_utils import make_awaitable
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable, simple_async_mock
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
|
|
||||||
from .. import unittest
|
|
||||||
|
|
||||||
|
|
||||||
class AppServiceHandlerTestCase(unittest.TestCase):
|
class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
"""Tests the ApplicationServicesHandler."""
|
"""Tests the ApplicationServicesHandler."""
|
||||||
@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
hs.get_datastore.return_value = self.mock_store
|
hs.get_datastore.return_value = self.mock_store
|
||||||
self.mock_store.get_received_ts.return_value = make_awaitable(0)
|
self.mock_store.get_received_ts.return_value = make_awaitable(0)
|
||||||
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
|
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
|
||||||
|
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
|
||||||
|
None
|
||||||
|
)
|
||||||
hs.get_application_service_api.return_value = self.mock_as_api
|
hs.get_application_service_api.return_value = self.mock_as_api
|
||||||
hs.get_application_service_scheduler.return_value = self.mock_scheduler
|
hs.get_application_service_scheduler.return_value = self.mock_scheduler
|
||||||
hs.get_clock.return_value = MockClock()
|
hs.get_clock.return_value = MockClock()
|
||||||
@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.handler.notify_interested_services(RoomStreamToken(None, 1))
|
self.handler.notify_interested_services(RoomStreamToken(None, 1))
|
||||||
|
|
||||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||||
interested_service, event
|
interested_service, events=[event]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_user_exists_unknown_user(self):
|
def test_query_user_exists_unknown_user(self):
|
||||||
@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
interested_service = self._mkservice(is_interested=True)
|
interested_service = self._mkservice(is_interested=True)
|
||||||
services = [interested_service]
|
services = [interested_service]
|
||||||
|
|
||||||
self.mock_store.get_app_services.return_value = services
|
self.mock_store.get_app_services.return_value = services
|
||||||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
||||||
579
|
579
|
||||||
@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
self.handler.notify_interested_services_ephemeral(
|
self.handler.notify_interested_services_ephemeral(
|
||||||
"receipt_key", 580, ["@fakerecipient:example.com"]
|
"receipt_key", 580, ["@fakerecipient:example.com"]
|
||||||
)
|
)
|
||||||
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
|
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||||
interested_service, [event]
|
interested_service, ephemeral=[event]
|
||||||
)
|
)
|
||||||
self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
|
self.mock_store.set_appservice_stream_type_pos.assert_called_once_with(
|
||||||
interested_service,
|
interested_service,
|
||||||
"read_receipt",
|
"read_receipt",
|
||||||
580,
|
580,
|
||||||
@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
self.handler.notify_interested_services_ephemeral(
|
self.handler.notify_interested_services_ephemeral(
|
||||||
"receipt_key", 580, ["@fakerecipient:example.com"]
|
"receipt_key", 580, ["@fakerecipient:example.com"]
|
||||||
)
|
)
|
||||||
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
|
# This method will be called, but with an empty list of events
|
||||||
|
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||||
|
interested_service, ephemeral=[]
|
||||||
|
)
|
||||||
|
|
||||||
def _mkservice(self, is_interested, protocols=None):
|
def _mkservice(self, is_interested, protocols=None):
|
||||||
service = Mock()
|
service = Mock()
|
||||||
@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
service.token = "mock_service_token"
|
service.token = "mock_service_token"
|
||||||
service.url = "mock_service_url"
|
service.url = "mock_service_url"
|
||||||
return service
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Tests that the ApplicationServicesHandler sends events to application
|
||||||
|
services correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
sendtodevice.register_servlets,
|
||||||
|
receipts.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
||||||
|
# we can track any outgoing ephemeral events
|
||||||
|
self.send_mock = simple_async_mock()
|
||||||
|
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
|
||||||
|
|
||||||
|
# Mock out application services, and allow defining our own in tests
|
||||||
|
self._services: List[ApplicationService] = []
|
||||||
|
self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
|
||||||
|
|
||||||
|
# A user on the homeserver.
|
||||||
|
self.local_user_device_id = "local_device"
|
||||||
|
self.local_user = self.register_user("local_user", "password")
|
||||||
|
self.local_user_token = self.login(
|
||||||
|
"local_user", "password", self.local_user_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# A user on the homeserver which lies within an appservice's exclusive user namespace.
|
||||||
|
self.exclusive_as_user_device_id = "exclusive_as_device"
|
||||||
|
self.exclusive_as_user = self.register_user("exclusive_as_user", "password")
|
||||||
|
self.exclusive_as_user_token = self.login(
|
||||||
|
"exclusive_as_user", "password", self.exclusive_as_user_device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.override_config(
|
||||||
|
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||||
|
)
|
||||||
|
def test_application_services_receive_local_to_device(self):
|
||||||
|
"""
|
||||||
|
Test that when a user sends a to-device message to another user
|
||||||
|
that is an application service's user namespace, the
|
||||||
|
application service will receive it.
|
||||||
|
"""
|
||||||
|
interested_appservice = self._register_application_service(
|
||||||
|
namespaces={
|
||||||
|
ApplicationService.NS_USERS: [
|
||||||
|
{
|
||||||
|
"regex": "@exclusive_as_user:.+",
|
||||||
|
"exclusive": True,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Have local_user send a to-device message to exclusive_as_user
|
||||||
|
message_content = {"some_key": "some really interesting value"}
|
||||||
|
chan = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
|
||||||
|
content={
|
||||||
|
"messages": {
|
||||||
|
self.exclusive_as_user: {
|
||||||
|
self.exclusive_as_user_device_id: message_content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
access_token=self.local_user_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
|
# Have exclusive_as_user send a to-device message to local_user
|
||||||
|
chan = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
|
||||||
|
content={
|
||||||
|
"messages": {
|
||||||
|
self.local_user: {self.local_user_device_id: message_content}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
access_token=self.exclusive_as_user_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
|
# Check if our application service - that is interested in exclusive_as_user - received
|
||||||
|
# the to-device message as part of an AS transaction.
|
||||||
|
# Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
|
||||||
|
#
|
||||||
|
# The uninterested application service should not have been notified at all.
|
||||||
|
self.send_mock.assert_called_once()
|
||||||
|
service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
|
||||||
|
|
||||||
|
# Assert that this was the same to-device message that local_user sent
|
||||||
|
self.assertEqual(service, interested_appservice)
|
||||||
|
self.assertEqual(to_device_messages[0]["type"], "m.room_key_request")
|
||||||
|
self.assertEqual(to_device_messages[0]["sender"], self.local_user)
|
||||||
|
|
||||||
|
# Additional fields 'to_user_id' and 'to_device_id' specifically for
|
||||||
|
# to-device messages via the AS API
|
||||||
|
self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
|
||||||
|
self.assertEqual(
|
||||||
|
to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id
|
||||||
|
)
|
||||||
|
self.assertEqual(to_device_messages[0]["content"], message_content)
|
||||||
|
|
||||||
|
@unittest.override_config(
|
||||||
|
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
|
||||||
|
)
|
||||||
|
def test_application_services_receive_bursts_of_to_device(self):
|
||||||
|
"""
|
||||||
|
Test that when a user sends >100 to-device messages at once, any
|
||||||
|
interested AS's will receive them in separate transactions.
|
||||||
|
|
||||||
|
Also tests that uninterested application services do not receive messages.
|
||||||
|
"""
|
||||||
|
# Register two application services with exclusive interest in a user
|
||||||
|
interested_appservices = []
|
||||||
|
for _ in range(2):
|
||||||
|
appservice = self._register_application_service(
|
||||||
|
namespaces={
|
||||||
|
ApplicationService.NS_USERS: [
|
||||||
|
{
|
||||||
|
"regex": "@exclusive_as_user:.+",
|
||||||
|
"exclusive": True,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
interested_appservices.append(appservice)
|
||||||
|
|
||||||
|
# ...and an application service which does not have any user interest.
|
||||||
|
self._register_application_service()
|
||||||
|
|
||||||
|
to_device_message_content = {
|
||||||
|
"some key": "some interesting value",
|
||||||
|
}
|
||||||
|
|
||||||
|
# We need to send a large burst of to-device messages. We also would like to
|
||||||
|
# include them all in the same application service transaction so that we can
|
||||||
|
# test large transactions.
|
||||||
|
#
|
||||||
|
# To do this, we can send a single to-device message to many user devices at
|
||||||
|
# once.
|
||||||
|
#
|
||||||
|
# We insert number_of_messages - 1 messages into the database directly. We'll then
|
||||||
|
# send a final to-device message to the real device, which will also kick off
|
||||||
|
# an AS transaction (as just inserting messages into the DB won't).
|
||||||
|
number_of_messages = 150
|
||||||
|
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
|
||||||
|
messages = {
|
||||||
|
self.exclusive_as_user: {
|
||||||
|
device_id: to_device_message_content for device_id in fake_device_ids
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a fake device per message. We can't send to-device messages to
|
||||||
|
# a device that doesn't exist.
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_datastore().db_pool.simple_insert_many(
|
||||||
|
desc="test_application_services_receive_burst_of_to_device",
|
||||||
|
table="devices",
|
||||||
|
keys=("user_id", "device_id"),
|
||||||
|
values=[
|
||||||
|
(
|
||||||
|
self.exclusive_as_user,
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
for device_id in fake_device_ids
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Seed the device_inbox table with our fake messages
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now have local_user send a final to-device message to exclusive_as_user. All unsent
|
||||||
|
# to-device messages should be sent to any application services
|
||||||
|
# interested in exclusive_as_user.
|
||||||
|
chan = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
|
||||||
|
content={
|
||||||
|
"messages": {
|
||||||
|
self.exclusive_as_user: {
|
||||||
|
self.exclusive_as_user_device_id: to_device_message_content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
access_token=self.local_user_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
|
self.send_mock.assert_called()
|
||||||
|
|
||||||
|
# Count the total number of to-device messages that were sent out per-service.
|
||||||
|
# Ensure that we only sent to-device messages to interested services, and that
|
||||||
|
# each interested service received the full count of to-device messages.
|
||||||
|
service_id_to_message_count: Dict[str, int] = {}
|
||||||
|
|
||||||
|
for call in self.send_mock.call_args_list:
|
||||||
|
service, _events, _ephemeral, to_device_messages = call[0]
|
||||||
|
|
||||||
|
# Check that this was made to an interested service
|
||||||
|
self.assertIn(service, interested_appservices)
|
||||||
|
|
||||||
|
# Add to the count of messages for this application service
|
||||||
|
service_id_to_message_count.setdefault(service.id, 0)
|
||||||
|
service_id_to_message_count[service.id] += len(to_device_messages)
|
||||||
|
|
||||||
|
# Assert that each interested service received the full count of messages
|
||||||
|
for count in service_id_to_message_count.values():
|
||||||
|
self.assertEqual(count, number_of_messages)
|
||||||
|
|
||||||
|
def _register_application_service(
|
||||||
|
self,
|
||||||
|
namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
|
||||||
|
) -> ApplicationService:
|
||||||
|
"""
|
||||||
|
Register a new application service, with the given namespaces of interest.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
namespaces: A dictionary containing any user, room or alias namespaces that
|
||||||
|
the application service is interested in.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The registered application service.
|
||||||
|
"""
|
||||||
|
# Create an application service
|
||||||
|
appservice = ApplicationService(
|
||||||
|
token=random_string(10),
|
||||||
|
hostname="example.com",
|
||||||
|
id=random_string(10),
|
||||||
|
sender="@as:example.com",
|
||||||
|
rate_limited=False,
|
||||||
|
namespaces=namespaces,
|
||||||
|
supports_ephemeral=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the application service
|
||||||
|
self._services.append(appservice)
|
||||||
|
|
||||||
|
return appservice
|
||||||
|
@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
service = Mock(id=self.as_list[0]["id"])
|
service = Mock(id=self.as_list[0]["id"])
|
||||||
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
||||||
txn = self.get_success(
|
txn = self.get_success(
|
||||||
defer.ensureDeferred(self.store.create_appservice_txn(service, events, []))
|
defer.ensureDeferred(
|
||||||
|
self.store.create_appservice_txn(service, events, [], [])
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEquals(txn.id, 1)
|
self.assertEquals(txn.id, 1)
|
||||||
self.assertEquals(txn.events, events)
|
self.assertEquals(txn.events, events)
|
||||||
@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
|
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
|
||||||
self.get_success(self._insert_txn(service.id, 9644, events))
|
self.get_success(self._insert_txn(service.id, 9644, events))
|
||||||
self.get_success(self._insert_txn(service.id, 9645, events))
|
self.get_success(self._insert_txn(service.id, 9645, events))
|
||||||
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
|
txn = self.get_success(
|
||||||
|
self.store.create_appservice_txn(service, events, [], [])
|
||||||
|
)
|
||||||
self.assertEquals(txn.id, 9646)
|
self.assertEquals(txn.id, 9646)
|
||||||
self.assertEquals(txn.events, events)
|
self.assertEquals(txn.events, events)
|
||||||
self.assertEquals(txn.service, service)
|
self.assertEquals(txn.service, service)
|
||||||
@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
service = Mock(id=self.as_list[0]["id"])
|
service = Mock(id=self.as_list[0]["id"])
|
||||||
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
||||||
self.get_success(self._set_last_txn(service.id, 9643))
|
self.get_success(self._set_last_txn(service.id, 9643))
|
||||||
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
|
txn = self.get_success(
|
||||||
|
self.store.create_appservice_txn(service, events, [], [])
|
||||||
|
)
|
||||||
self.assertEquals(txn.id, 9644)
|
self.assertEquals(txn.id, 9644)
|
||||||
self.assertEquals(txn.events, events)
|
self.assertEquals(txn.events, events)
|
||||||
self.assertEquals(txn.service, service)
|
self.assertEquals(txn.service, service)
|
||||||
@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
|
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
|
||||||
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
|
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
|
||||||
|
|
||||||
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
|
txn = self.get_success(
|
||||||
|
self.store.create_appservice_txn(service, events, [], [])
|
||||||
|
)
|
||||||
self.assertEquals(txn.id, 9644)
|
self.assertEquals(txn.id, 9644)
|
||||||
self.assertEquals(txn.events, events)
|
self.assertEquals(txn.events, events)
|
||||||
self.assertEquals(txn.service, service)
|
self.assertEquals(txn.service, service)
|
||||||
@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
|
|||||||
ValueError,
|
ValueError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_type_stream_id_for_appservice(self) -> None:
|
def test_set_appservice_stream_type_pos(self) -> None:
|
||||||
read_receipt_value = 1024
|
read_receipt_value = 1024
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_type_stream_id_for_appservice(
|
self.store.set_appservice_stream_type_pos(
|
||||||
self.service, "read_receipt", read_receipt_value
|
self.service, "read_receipt", read_receipt_value
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(result, read_receipt_value)
|
self.assertEqual(result, read_receipt_value)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_type_stream_id_for_appservice(
|
self.store.set_appservice_stream_type_pos(
|
||||||
self.service, "presence", read_receipt_value
|
self.service, "presence", read_receipt_value
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(result, read_receipt_value)
|
self.assertEqual(result, read_receipt_value)
|
||||||
|
|
||||||
def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
|
def test_set_appservice_stream_type_pos_invalid_type(self) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
|
self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024),
|
||||||
ValueError,
|
ValueError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user