Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617)

Co-authored-by: Erik Johnston <erik@matrix.org>
This commit is contained in:
reivilibre 2022-02-24 17:55:45 +00:00 committed by GitHub
parent 41cf4c2cf6
commit 2cc5ea933d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 528 additions and 38 deletions

View file

@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@ -92,6 +94,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@ -114,7 +118,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[], to_device_messages=[]
service=service,
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4)
event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
def test_send_single_event_with_queue(self):
d = defer.Deferred()
@ -231,11 +240,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [])
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], None, None
)
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@ -261,15 +272,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@ -288,13 +299,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], None, None
)
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], None, None
)
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], None, None
)
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
@ -302,14 +319,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
)
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
)
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
@ -324,13 +345,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, []
service, [], event_list_2 + event_list_3, [], None, None
)
self.assertEquals(2, self.txn_ctrl.send.call_count)
@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], None, None
)
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
self.assertEquals(2, self.txn_ctrl.send.call_count)

View file

@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
from synapse.appservice import ApplicationService
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@ -428,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
#
# The uninterested application service should not have been notified at all.
self.send_mock.assert_called_once()
service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
(
service,
_events,
_ephemeral,
to_device_messages,
_otks,
_fbks,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
self.assertEqual(service, interested_appservice)
@ -540,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages = call[0]
service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
@ -582,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self._services.append(appservice)
return appservice
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4
ARG_FALLBACK_KEYS = 5
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
register.register_servlets,
room.register_servlets,
sendtodevice.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
self.send_mock = simple_async_mock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
# Define an application service for the tests
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
"as1.invalid",
"as1",
"@as.sender:test",
namespaces={
"users": [
{"regex": "@_as_.*:test", "exclusive": True},
{"regex": "@as.sender:test", "exclusive": True},
]
},
msc3202_transaction_extensions=True,
)
self.hs.get_datastores().main.services_cache = [self._service]
# Register some appservice users
self._sender_user, self._sender_device = self.register_appservice_user(
"as.sender", self._service_token
)
self._namespaced_user, self._namespaced_device = self.register_appservice_user(
"_as_user1", self._service_token
)
# Register a real user as well.
self._real_user = self.register_user("real.user", "meow")
self._real_user_token = self.login("real.user", "meow")
async def _add_otks_for_device(
self, user_id: str, device_id: str, otk_count: int
) -> None:
"""
Add some dummy keys. It doesn't matter if they're not a real algorithm;
that should be opaque to the server anyway.
"""
await self.hs.get_datastores().main.add_e2e_one_time_keys(
user_id,
device_id,
self.clock.time_msec(),
[("algo", f"k{i}", "{}") for i in range(otk_count)],
)
async def _add_fallback_key_for_device(
self, user_id: str, device_id: str, used: bool
) -> None:
"""
Adds a fake fallback key to a device, optionally marking it as used
right away.
"""
store = self.hs.get_datastores().main
await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"})
if used is True:
# Mark the key as used
await store.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": "algo",
"key_id": "fk",
},
updatevalues={"used": True},
desc="_get_fallback_key_set_used",
)
def _set_up_devices_and_a_room(self) -> str:
"""
Helper to set up devices for all the users
and a room for the users to talk in.
"""
async def preparation():
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True
)
await self._add_otks_for_device(
self._namespaced_user, self._namespaced_device, 36
)
await self._add_fallback_key_for_device(
self._namespaced_user, self._namespaced_device, used=False
)
# Register a device for the real user, too, so that we can later ensure
# that we don't leak information to the AS about the non-AS user.
await self.hs.get_datastores().main.store_device(
self._real_user, "REALDEV", "UltraMatrix 3000"
)
await self._add_otks_for_device(self._real_user, "REALDEV", 50)
self.get_success(preparation())
room_id = self.helper.create_room_as(
self._real_user, is_public=True, tok=self._real_user_token
)
self.helper.join(
room_id,
self._namespaced_user,
tok=self._service_token,
appservice_user_id=self._namespaced_user,
)
# Check it was called for sanity. (This was to send the join event to the AS.)
self.send_mock.assert_called()
self.send_mock.reset_mock()
return room_id
@override_config(
{"experimental_features": {"msc3202_transaction_extensions": True}}
)
def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus(
self,
) -> None:
"""
Tests that:
- the AS receives one-time key counts and unused fallback keys for:
- the specified sender; and
- any user who is in receipt of the PDUs
"""
room_id = self._set_up_devices_and_a_room()
# Send a message into the AS's room
self.helper.send(room_id, "woof woof", tok=self._real_user_token)
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
self.assertEqual(
otks,
{
"@as.sender:test": {self._sender_device: {"algo": 42}},
"@_as_user1:test": {self._namespaced_device: {"algo": 36}},
},
)
self.assertEqual(
unused_fallbacks,
{
"@as.sender:test": {self._sender_device: []},
"@_as_user1:test": {self._namespaced_device: ["algo"]},
},
)

View file

@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [])
self.store.create_appservice_txn(service, events, [], [], {}, {})
)
)
self.assertEquals(txn.id, 1)
@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
self.store.create_appservice_txn(service, events, [], [], {}, {})
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
self.store.create_appservice_txn(service, events, [], [], {}, {})
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
self.store.create_appservice_txn(service, events, [], [], {}, {})
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)