Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617)

Co-authored-by: Erik Johnston <erik@matrix.org>
This commit is contained in:
reivilibre 2022-02-24 17:55:45 +00:00 committed by GitHub
parent 41cf4c2cf6
commit 2cc5ea933d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 528 additions and 38 deletions

View file

@ -31,6 +31,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Type for the `device_one_time_key_counts` field in an appservice transaction
# user ID -> {device ID -> {algorithm -> count}}
TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
# Type for the `device_unused_fallback_keys` field in an appservice transaction
# user ID -> {device ID -> [algorithm]}
TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]
class ApplicationServiceState(Enum):
DOWN = "down"
@ -72,6 +80,7 @@ class ApplicationService:
rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None,
supports_ephemeral: bool = False,
msc3202_transaction_extensions: bool = False,
):
self.token = token
self.url = (
@ -84,6 +93,7 @@ class ApplicationService:
self.id = id
self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral
self.msc3202_transaction_extensions = msc3202_transaction_extensions
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@ -339,12 +349,16 @@ class AppServiceTransaction:
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
):
self.service = service
self.id = id
self.events = events
self.ephemeral = ephemeral
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@ -359,6 +373,8 @@ class AppServiceTransaction:
events=self.events,
ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
txn_id=self.id,
)

View file

@ -19,6 +19,11 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient):
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
txn_id: Optional[int] = None,
) -> bool:
"""
@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it
body: Dict[str, List[JsonDict]] = {"events": serialized_events}
body: JsonDict = {"events": serialized_events}
if service.supports_ephemeral:
body.update(
{
@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
"org.matrix.msc3202.device_one_time_key_counts"
] = one_time_key_counts
if unused_fallback_keys:
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
try:
await self.put_json(
uri=uri,

View file

@ -54,12 +54,19 @@ from typing import (
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.appservice.api import ApplicationServiceApi
from synapse.events import EventBase
from synapse.logging.context import run_in_background
@ -96,7 +103,7 @@ class ApplicationServiceScheduler:
self.as_api = hs.get_application_service_api()
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs)
async def start(self) -> None:
logger.info("Starting appservice scheduler")
@ -153,7 +160,9 @@ class _ServiceQueuer:
appservice at a given time.
"""
def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
def __init__(
self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer"
):
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
@ -165,6 +174,10 @@ class _ServiceQueuer:
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
self._msc3202_transaction_extensions_enabled: bool = (
hs.config.experimental.msc3202_transaction_extensions
)
self._store = hs.get_datastores().main
def start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one
@ -202,15 +215,84 @@ class _ServiceQueuer:
if not events and not ephemeral and not to_device_messages_to_send:
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None
if (
self._msc3202_transaction_extensions_enabled
and service.msc3202_transaction_extensions
):
# Compute the one-time key counts and fallback key usage states
# for the users which are mentioned in this transaction,
# as well as the appservice's sender.
(
one_time_key_counts,
unused_fallback_keys,
) = await self._compute_msc3202_otk_counts_and_fallback_keys(
service, events, ephemeral, to_device_messages_to_send
)
try:
await self.txn_ctrl.send(
service, events, ephemeral, to_device_messages_to_send
service,
events,
ephemeral,
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
)
except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
async def _compute_msc3202_otk_counts_and_fallback_keys(
self,
service: ApplicationService,
events: Iterable[EventBase],
ephemerals: Iterable[JsonDict],
to_device_messages: Iterable[JsonDict],
) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
- first computes a list of application services users that may have
interesting updates to the one-time key counts or fallback key usage.
- then computes one-time key counts and fallback key usages for those users.
Given a list of application service users that are interesting,
compute one-time key counts and fallback key usages for the users.
"""
# Set of 'interesting' users who may have updates
users: Set[str] = set()
# The sender is always included
users.add(service.sender)
# All AS users that would receive the PDUs or EDUs sent to these rooms
# are classed as 'interesting'.
rooms_of_interesting_users: Set[str] = set()
# PDUs
rooms_of_interesting_users.update(event.room_id for event in events)
# EDUs
rooms_of_interesting_users.update(
ephemeral["room_id"] for ephemeral in ephemerals
)
# Look up the AS users in those rooms
for room_id in rooms_of_interesting_users:
users.update(
await self._store.get_app_service_users_in_room(room_id, service)
)
# Add recipients of to-device messages.
# device_message["user_id"] is the ID of the recipient.
users.update(device_message["user_id"] for device_message in to_device_messages)
# Compute and return the counts / fallback key usage states
otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users)
return otk_counts, unused_fbks
class _TransactionController:
"""Transaction manager.
@ -238,6 +320,8 @@ class _TransactionController:
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@ -248,6 +332,10 @@ class _TransactionController:
events: The persistent events to include in the transaction.
ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@ -255,6 +343,8 @@ class _TransactionController:
events=events,
ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View file

@ -166,6 +166,16 @@ def _load_appservice(
supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
# Opt-in flag for the MSC3202-specific transactional behaviour.
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(
"The `org.matrix.msc3202` option should be true or false if specified."
)
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@ -174,8 +184,9 @@ def _load_appservice(
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols,
rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,
supports_ephemeral=supports_ephemeral,
msc3202_transaction_extensions=msc3202_transaction_extensions,
)

View file

@ -47,11 +47,6 @@ class ExperimentalConfig(Config):
# MSC3030 (Jump to date API endpoint)
self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
# The portion of MSC3202 which is related to device masquerading.
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# MSC2409 (this setting only relates to optionally sending to-device messages).
# Presence, typing and read receipt EDUs are already sent to application services that
# have opted in to receive them. If enabled, this adds to-device messages to that list.
@ -59,6 +54,17 @@ class ExperimentalConfig(Config):
"msc2409_to_device_messages_enabled", False
)
# The portion of MSC3202 which is related to device masquerading.
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)
# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

View file

@ -20,14 +20,18 @@ from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -56,7 +60,7 @@ def _make_exclusive_regex(
return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return service
return None
@cached(iterable=True, cache_context=True)
async def get_app_service_users_in_room(
self,
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
users_in_room = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
return list(filter(app_service.is_interested_in_user, users_in_room))
class ApplicationServiceStore(ApplicationServiceWorkerStore):
# This is currently empty due to there not being any AS storage functions
@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore(
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
Returns:
A new transaction.
@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
)
return await self.db_pool.runInteraction(
@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys
# are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
id=entry["txn_id"],
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:

View file

@ -29,6 +29,10 @@ import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
) -> TransactionOneTimeKeyCounts:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
transactions.
Return structure is of the shape:
user_id -> device_id -> algorithm -> count
Empty algorithm -> count dicts are created if needed to represent a
lack of unused one-time keys.
"""
def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
sql = f"""
SELECT user_id, device_id, algorithm, COUNT(key_id)
FROM devices
LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
WHERE {user_in_where_clause}
GROUP BY user_id, device_id, algorithm
"""
txn.execute(sql, user_parameters)
result: TransactionOneTimeKeyCounts = {}
for user_id, device_id, algorithm, count in txn:
# We deliberately construct empty dictionaries for
# users and devices without any unused one-time keys.
# We *could* omit these empty dicts if there have been no
# changes since the last transaction, but we currently don't
# do any change tracking!
device_count_by_algo = result.setdefault(user_id, {}).setdefault(
device_id, {}
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_count_by_algo[algorithm] = count
return result
return await self.db_pool.runInteraction(
"count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
)
async def get_e2e_bulk_unused_fallback_key_types(
self, user_ids: Collection[str]
) -> TransactionUnusedFallbackKeys:
"""
Finds, in bulk, the types of unused fallback keys for all the users specified.
Intended to be used by application services for populating unused fallback
keys in transactions.
Return structure is of the shape:
user_id -> device_id -> algorithms
Empty lists are created for devices if there are no unused fallback
keys. This matches the response structure of MSC3202.
"""
if len(user_ids) == 0:
return {}
def _get_bulk_e2e_unused_fallback_keys_txn(
txn: LoggingTransaction,
) -> TransactionUnusedFallbackKeys:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "devices.user_id", user_ids
)
# We can't use USING here because we require the `.used` condition
# to be part of the JOIN condition so that we generate empty lists
# when all keys are used (as opposed to just when there are no keys at all).
sql = f"""
SELECT devices.user_id, devices.device_id, algorithm
FROM devices
LEFT JOIN e2e_fallback_keys_json AS fallback_keys
ON devices.user_id = fallback_keys.user_id
AND devices.device_id = fallback_keys.device_id
AND NOT fallback_keys.used
WHERE
{user_in_where_clause}
"""
txn.execute(sql, user_parameters)
result: TransactionUnusedFallbackKeys = {}
for user_id, device_id, algorithm in txn:
# We deliberately construct empty dictionaries and lists for
# users and devices without any unused fallback keys.
# We *could* omit these empty dicts if there have been no
# changes since the last transaction, but we currently don't
# do any change tracking!
device_unused_keys = result.setdefault(user_id, {}).setdefault(
device_id, []
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_unused_keys.append(algorithm)
return result
return await self.db_pool.runInteraction(
"_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn
)
async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None: