Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314)

Experimental support for MSC3983 is behind a configuration flag.
If enabled, and a user is exclusively owned by an application service,
the application service will be queried for one-time keys *if*
none have been uploaded to Synapse.
Patrick Cloke 2023-03-28 14:26:27 -04:00 committed by GitHub
parent 57481ca694
commit 5282ba1e2b
9 changed files with 354 additions and 28 deletions
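
As context for the diff below, this is roughly the wire exchange the new code drives. The endpoint path, header, and payload shapes are taken from the code in this commit; the hostname, token, and key values are invented for illustration.

# Illustrative sketch only: hostname, token and key values are made up.
uri = "https://as.example.com/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
headers = {"Authorization": ["Bearer <hs_token>"]}

# Synapse POSTs the unfulfilled claims, grouped per user and device:
request_body = {"@alice:example.org": {"DEVICE_1": ["signed_curve25519"]}}

# The appservice replies with user ID -> device ID -> key ID -> key object.
# Anything it omits stays unfulfilled and falls through to fallback keys.
response_body = {
    "@alice:example.org": {
        "DEVICE_1": {"signed_curve25519:AAAAHg": {"key": "...", "signatures": {}}}
    }
}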


@@ -0,0 +1 @@
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).


@@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from an application service.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
A map of user ID -> a map of device ID -> a map of key ID -> JSON dict.
A copy of the input which has not been fulfilled because the
appservice doesn't support this endpoint or has not returned
data for that tuple.
"""
if service.url is None:
return {}, query
# This is required by the configuration.
assert service.hs_token is not None
# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
response = await self.post_json_get_json(
uri,
body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
except CodeMessageException as e:
# The appservice doesn't support this endpoint.
if e.code == 404 or e.code == 405:
return {}, query
logger.warning("claim_keys to %s received %s", uri, e.code)
return {}, query
except Exception as ex:
logger.warning("claim_keys to %s threw exception %s", uri, ex)
return {}, query
# Check if the appservice fulfilled all of the queried user/device/algorithms
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
missing = [
(user_id, device, algorithm)
for user_id, device, algorithm in query
if algorithm not in response.get(user_id, {}).get(device, [])
]
return response, missing
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
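
The reconciliation at the end of claim_client_keys can be seen in isolation with plain dictionaries. A minimal sketch, assuming (as the unit test further down does) that the third element of each query tuple is the string the caller wants to see among the returned key IDs; all values are invented:

from typing import Dict, List, Tuple

query: List[Tuple[str, str, str]] = [
    ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
    ("@alice:example.org", "DEVICE_2", "signed_curve25519:BBBBHg"),
]
# Hypothetical appservice response: only DEVICE_1 is fulfilled.
response: Dict[str, Dict[str, Dict[str, dict]]] = {
    "@alice:example.org": {"DEVICE_1": {"signed_curve25519:AAAAHg": {}}},
}

# Mirrors the list comprehension above: anything not present in the response
# is reported back to the caller as still missing.
missing = [
    (user_id, device, algorithm)
    for user_id, device, algorithm in query
    if algorithm not in response.get(user_id, {}).get(device, [])
]
assert missing == [("@alice:example.org", "DEVICE_2", "signed_curve25519:BBBBHg")]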


@@ -74,6 +74,11 @@ class ExperimentalConfig(Config):
"msc3202_transaction_extensions", False
)
# MSC3983: Proxying OTK claim requests to exclusive ASes.
self.msc3983_appservice_otk_claims: bool = experimental.get(
"msc3983_appservice_otk_claims", False
)
# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
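
The new msc3983_appservice_otk_claims flag above is read from the experimental_features section of the homeserver configuration. A minimal sketch of the equivalent dict form, matching the override_config used by the tests at the end of this commit (the YAML file layout is assumed):

# Dict form of the configuration that enables the feature; the same shape is
# passed to @override_config in the E2eKeysHandlerTestCase test below.
config = {
    "experimental_features": {
        "msc3983_appservice_otk_claims": True,
    }
}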


@@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
@@ -135,6 +135,7 @@ class FederationServer(FederationBase):
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._e2e_keys_handler = hs.get_e2e_keys_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -1012,15 +1013,14 @@ class FederationServer(FederationBase):
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
logger.info(
"Claimed one-time-keys: %s",


@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from prometheus_client import Counter
@@ -829,3 +838,66 @@ class ApplicationServicesHandler:
if unknown_user:
return await self.query_user_exists(user_id)
return True
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
"""Claim one time keys from application services.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
An iterable of maps of user ID -> a map of device ID -> a map of key ID -> JSON dict.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
providing OTKs).
"""
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
missing = []
for user_id, device, algorithm in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
)
continue
# Query each service in parallel.
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
self.appservice_api.claim_client_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
service_query,
)
for service_id, service_query in query_by_appservice.items()
],
consumeErrors=True,
)
)
# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
for success, result in results:
if success:
claimed_keys.append(result[0])
missing.extend(result[1])
return claimed_keys, missing
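
The core of the handler above is partitioning the query by the exclusively-owning appservice. A standalone sketch of that idea with invented users and a stand-in for ApplicationService.is_exclusive_user (the real code also consults get_if_app_services_interested_in_user first):

from typing import Dict, List, Tuple

def is_exclusive_user(service_id: str, user_id: str) -> bool:
    # Stand-in for ApplicationService.is_exclusive_user.
    return service_id == "as1" and user_id.startswith("@boris:")

services = ["as1", "as2"]
query: List[Tuple[str, str, str]] = [
    ("@boris:test", "xyz", "signed_curve25519"),
    ("@alice:test", "abc", "signed_curve25519"),
]

query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
missing: List[Tuple[str, str, str]] = []
for user_id, device, algorithm in query:
    owners = [s for s in services if is_exclusive_user(s, user_id)]
    if not owners:
        # Not exclusively owned by any appservice: fall back to other sources.
        missing.append((user_id, device, algorithm))
        continue
    query_by_appservice.setdefault(owners[0], []).append((user_id, device, algorithm))

assert query_by_appservice == {"as1": [("@boris:test", "xyz", "signed_curve25519")]}
assert missing == [("@alice:test", "abc", "signed_curve25519")]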


@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
@@ -53,6 +52,7 @@ class E2eKeysHandler:
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
@@ -88,6 +88,10 @@ class E2eKeysHandler:
max_count=10,
)
self._query_appservices_for_otks = (
hs.config.experimental.msc3983_appservice_otk_claims
)
@trace
@cancellable
async def query_devices(
@@ -542,6 +546,42 @@
return ret
async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
1. Attempt to claim OTKs from the database.
2. Ask application services if they provide OTKs.
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
An iterable of maps of user ID -> a map of device ID -> a map of key ID -> JSON dict.
"""
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
@@ -561,17 +601,18 @@
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
# Remote failures.
failures: Dict[str, JsonDict] = {}
@trace
async def claim_client_keys(destination: str) -> None:
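
Taken together, claim_local_one_time_keys now tries three sources in order and hands the partial maps back to its caller for merging. A schematic of the data flow with invented results (the store and appservice lookups are faked as plain values):

# Faked results for one user with two devices; only "abc" has a DB one-time key.
otk_results = {"@boris:test": {"abc": {"alg1:k2": {"key": "key2"}}}}
not_found = [("@boris:test", "xyz", "alg1")]

# Step 2 only runs when msc3983_appservice_otk_claims is enabled; assume the
# appservice had nothing to offer here.
appservice_results: list = []

# Step 3: fallback keys for whatever is still unfulfilled.
fallback_results = {"@boris:test": {"xyz": {"alg1:k1": {"key": "fallback_key1"}}}}

# The handler returns the partial maps in order; callers (federation and C-S
# /keys/claim) merge them, so each queried tuple appears at most once.
combined = (otk_results, *appservice_results, fallback_results)
assert len(combined) == 2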


@@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
A map of user ID -> a map of device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled.
"""
@trace
@@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str]] = []
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
else:
missing.append((user_id, device_id, algorithm))
return results, missing
async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map of device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
@@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
return results
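
The storage contract changes in two ways: claim_e2e_one_time_keys now returns the unfulfilled tuples alongside the claimed keys, and key JSON is decoded before being handed back. A sketch of the new return shape with invented values:

# Previously the method returned only the first element, with each key as a raw
# JSON string; it now returns decoded dicts plus the tuples it could not satisfy.
claimed = {
    "@alice:example.org": {"DEVICE_1": {"signed_curve25519:AAAAHg": {"key": "..."}}}
}
missing = [("@alice:example.org", "DEVICE_2", "signed_curve25519")]
result = (claimed, missing)

# Callers such as claim_local_one_time_keys unpack it as:
otk_results, not_found = result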


@@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(self.request_url, URL_LOCATION)
self.assertEqual(result, SUCCESS_RESULT_LOCATION)
def test_claim_keys(self) -> None:
"""
Tests that the /keys/claim response is properly parsed for missing
keys.
"""
RESPONSE: JsonDict = {
"@alice:example.org": {
"DEVICE_1": {
"signed_curve25519:AAAAHg": {
# We don't really care about the content of the keys,
# they get passed back transparently.
},
"signed_curve25519:BBBBHg": {},
},
"DEVICE_2": {"signed_curve25519:CCCCHg": {}},
},
}
async def post_json_get_json(
uri: str,
post_json: Any,
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> JsonDict:
# Ensure the access token is passed as a Bearer header.
if not headers.get("Authorization"):
raise RuntimeError("Access token not provided")
self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
return RESPONSE
# We assign to a method, which mypy doesn't like.
self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment]
MISSING_KEYS = [
# Known user, known device, missing algorithm.
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
# Known user, missing device.
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
# Unknown user.
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
]
claimed_keys, missing = self.get_success(
self.api.claim_client_keys(
self.service,
[
# Found devices
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
]
+ MISSING_KEYS,
)
)
self.assertEqual(claimed_keys, RESPONSE)
self.assertEqual(missing, MISSING_KEYS)


@@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
return self.setup_test_homeserver(
federation_client=mock.Mock(), application_service_api=self.appservice_api
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
@@ -941,3 +947,71 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# The two requests to the local homeserver should be identical.
self.assertEqual(response_1, response_2)
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
def test_query_appservice(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id_1 = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
device_id_2 = "abc"
otk = {"alg1:k2": "key2"}
# Inject an appservice interested in this user.
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
self.hs.get_datastores().main.services_cache = [appservice]
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
[appservice]
)
# Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
)
# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(res, [])
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should ask the appservice, then
# query the fallback keys.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{
"one_time_keys": {
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
}
},
timeout=None,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: fallback_key, device_id_2: otk}
},
},
)
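
For completeness, a hypothetical sketch of the appservice side of this exchange. This is not part of the commit; it only illustrates the request/response contract the new test above exercises, with plain dicts standing in for a real web framework:

from typing import Any, Dict, List

def handle_msc3983_keys_claim(
    body: Dict[str, Dict[str, List[str]]],
    uploaded_keys: Dict[str, Dict[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """Hypothetical handler for
    POST /_matrix/app/unstable/org.matrix.msc3983/keys/claim.

    `body` maps user ID -> device ID -> requested algorithms; the response maps
    user ID -> device ID -> key ID -> key object, omitting anything unavailable.
    """
    response: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for user_id, devices in body.items():
        for device_id, algorithms in devices.items():
            keys = uploaded_keys.get(user_id, {}).get(device_id, {})
            for algorithm in algorithms:
                for key_id, key in keys.items():
                    if key_id.startswith(f"{algorithm}:"):
                        response.setdefault(user_id, {}).setdefault(device_id, {})[
                            key_id
                        ] = key
                        break
    return response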