mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314)
Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse.
This commit is contained in:
parent
57481ca694
commit
5282ba1e2b
1
changelog.d/15314.feature
Normal file
1
changelog.d/15314.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
|
@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
failed_transactions_counter.labels(service.id).inc()
|
failed_transactions_counter.labels(service.id).inc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def claim_client_keys(
|
||||||
|
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
|
||||||
|
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||||
|
"""Claim one time keys from an application service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of:
|
||||||
|
A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
|
||||||
|
|
||||||
|
A copy of the input which has not been fulfilled because the
|
||||||
|
appservice doesn't support this endpoint or has not returned
|
||||||
|
data for that tuple.
|
||||||
|
"""
|
||||||
|
if service.url is None:
|
||||||
|
return {}, query
|
||||||
|
|
||||||
|
# This is required by the configuration.
|
||||||
|
assert service.hs_token is not None
|
||||||
|
|
||||||
|
# Create the expected payload shape.
|
||||||
|
body: Dict[str, Dict[str, List[str]]] = {}
|
||||||
|
for user_id, device, algorithm in query:
|
||||||
|
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
|
||||||
|
|
||||||
|
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
|
||||||
|
try:
|
||||||
|
response = await self.post_json_get_json(
|
||||||
|
uri,
|
||||||
|
body,
|
||||||
|
headers={"Authorization": [f"Bearer {service.hs_token}"]},
|
||||||
|
)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
# The appservice doesn't support this endpoint.
|
||||||
|
if e.code == 404 or e.code == 405:
|
||||||
|
return {}, query
|
||||||
|
logger.warning("claim_keys to %s received %s", uri, e.code)
|
||||||
|
return {}, query
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("claim_keys to %s threw exception %s", uri, ex)
|
||||||
|
return {}, query
|
||||||
|
|
||||||
|
# Check if the appservice fulfilled all of the queried user/device/algorithms
|
||||||
|
# or if some are still missing.
|
||||||
|
#
|
||||||
|
# TODO This places a lot of faith in the response shape being correct.
|
||||||
|
missing = [
|
||||||
|
(user_id, device, algorithm)
|
||||||
|
for user_id, device, algorithm in query
|
||||||
|
if algorithm not in response.get(user_id, {}).get(device, [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return response, missing
|
||||||
|
|
||||||
def _serialize(
|
def _serialize(
|
||||||
self, service: "ApplicationService", events: Iterable[EventBase]
|
self, service: "ApplicationService", events: Iterable[EventBase]
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
|
@ -74,6 +74,11 @@ class ExperimentalConfig(Config):
|
|||||||
"msc3202_transaction_extensions", False
|
"msc3202_transaction_extensions", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MSC3983: Proxying OTK claim requests to exclusive ASes.
|
||||||
|
self.msc3983_appservice_otk_claims: bool = experimental.get(
|
||||||
|
"msc3983_appservice_otk_claims", False
|
||||||
|
)
|
||||||
|
|
||||||
# MSC3706 (server-side support for partial state in /send_join responses)
|
# MSC3706 (server-side support for partial state in /send_join responses)
|
||||||
# Synapse will always serve partial state responses to requests using the stable
|
# Synapse will always serve partial state responses to requests using the stable
|
||||||
# query parameter `omit_members`. If this flag is set, Synapse will also serve
|
# query parameter `omit_members`. If this flag is set, Synapse will also serve
|
||||||
|
@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock
|
|||||||
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
|
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
|
||||||
from synapse.storage.roommember import MemberSummary
|
from synapse.storage.roommember import MemberSummary
|
||||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||||
from synapse.util import json_decoder, unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
|
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import parse_server_name
|
from synapse.util.stringutils import parse_server_name
|
||||||
@ -135,6 +135,7 @@ class FederationServer(FederationBase):
|
|||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self._room_member_handler = hs.get_room_member_handler()
|
self._room_member_handler = hs.get_room_member_handler()
|
||||||
|
self._e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
self._state_storage_controller = hs.get_storage_controllers().state
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
|
|
||||||
@ -1012,15 +1013,14 @@ class FederationServer(FederationBase):
|
|||||||
query.append((user_id, device_id, algorithm))
|
query.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||||
results = await self.store.claim_e2e_one_time_keys(query)
|
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
|
||||||
|
|
||||||
json_result: Dict[str, Dict[str, dict]] = {}
|
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
for user_id, device_keys in results.items():
|
for result in results:
|
||||||
for device_id, keys in device_keys.items():
|
for user_id, device_keys in result.items():
|
||||||
for key_id, json_str in keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
json_result.setdefault(user_id, {})[device_id] = {
|
for key_id, key in keys.items():
|
||||||
key_id: json_decoder.decode(json_str)
|
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Claimed one-time-keys: %s",
|
"Claimed one-time-keys: %s",
|
||||||
|
@ -12,7 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
@ -829,3 +838,66 @@ class ApplicationServicesHandler:
|
|||||||
if unknown_user:
|
if unknown_user:
|
||||||
return await self.query_user_exists(user_id)
|
return await self.query_user_exists(user_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def claim_e2e_one_time_keys(
|
||||||
|
self, query: Iterable[Tuple[str, str, str]]
|
||||||
|
) -> Tuple[
|
||||||
|
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
|
||||||
|
]:
|
||||||
|
"""Claim one time keys from application services.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of:
|
||||||
|
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||||
|
|
||||||
|
A copy of the input which has not been fulfilled (either because
|
||||||
|
they are not appservice users or the appservice does not support
|
||||||
|
providing OTKs).
|
||||||
|
"""
|
||||||
|
services = self.store.get_app_services()
|
||||||
|
|
||||||
|
# Partition the users by appservice.
|
||||||
|
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
|
||||||
|
missing = []
|
||||||
|
for user_id, device, algorithm in query:
|
||||||
|
if not self.store.get_if_app_services_interested_in_user(user_id):
|
||||||
|
missing.append((user_id, device, algorithm))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find the associated appservice.
|
||||||
|
for service in services:
|
||||||
|
if service.is_exclusive_user(user_id):
|
||||||
|
query_by_appservice.setdefault(service.id, []).append(
|
||||||
|
(user_id, device, algorithm)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Query each service in parallel.
|
||||||
|
results = await make_deferred_yieldable(
|
||||||
|
defer.DeferredList(
|
||||||
|
[
|
||||||
|
run_in_background(
|
||||||
|
self.appservice_api.claim_client_keys,
|
||||||
|
# We know this must be an app service.
|
||||||
|
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
|
||||||
|
service_query,
|
||||||
|
)
|
||||||
|
for service_id, service_query in query_by_appservice.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch together the results -- they are all independent (since they
|
||||||
|
# require exclusive control over the users). They get returned as a list
|
||||||
|
# and the caller combines them.
|
||||||
|
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
|
||||||
|
for success, result in results:
|
||||||
|
if success:
|
||||||
|
claimed_keys.append(result[0])
|
||||||
|
missing.extend(result[1])
|
||||||
|
|
||||||
|
return claimed_keys, missing
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
@ -53,6 +52,7 @@ class E2eKeysHandler:
|
|||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.federation = hs.get_federation_client()
|
self.federation = hs.get_federation_client()
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
self._appservice_handler = hs.get_application_service_handler()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@ -88,6 +88,10 @@ class E2eKeysHandler:
|
|||||||
max_count=10,
|
max_count=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._query_appservices_for_otks = (
|
||||||
|
hs.config.experimental.msc3983_appservice_otk_claims
|
||||||
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@cancellable
|
@cancellable
|
||||||
async def query_devices(
|
async def query_devices(
|
||||||
@ -542,6 +546,42 @@ class E2eKeysHandler:
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
async def claim_local_one_time_keys(
|
||||||
|
self, local_query: List[Tuple[str, str, str]]
|
||||||
|
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||||
|
"""Claim one time keys for local users.
|
||||||
|
|
||||||
|
1. Attempt to claim OTKs from the database.
|
||||||
|
2. Ask application services if they provide OTKs.
|
||||||
|
3. Attempt to fetch fallback keys from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||||
|
|
||||||
|
# If the application services have not provided any keys via the C-S
|
||||||
|
# API, query it directly for one-time keys.
|
||||||
|
if self._query_appservices_for_otks:
|
||||||
|
(
|
||||||
|
appservice_results,
|
||||||
|
not_found,
|
||||||
|
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||||
|
else:
|
||||||
|
appservice_results = []
|
||||||
|
|
||||||
|
# For each user that does not have a one-time keys available, see if
|
||||||
|
# there is a fallback key.
|
||||||
|
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
|
||||||
|
|
||||||
|
# Return the results in order, each item from the input query should
|
||||||
|
# only appear once in the combined list.
|
||||||
|
return (otk_results, *appservice_results, fallback_results)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def claim_one_time_keys(
|
async def claim_one_time_keys(
|
||||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
||||||
@ -561,17 +601,18 @@ class E2eKeysHandler:
|
|||||||
set_tag("local_key_query", str(local_query))
|
set_tag("local_key_query", str(local_query))
|
||||||
set_tag("remote_key_query", str(remote_queries))
|
set_tag("remote_key_query", str(remote_queries))
|
||||||
|
|
||||||
results = await self.store.claim_e2e_one_time_keys(local_query)
|
results = await self.claim_local_one_time_keys(local_query)
|
||||||
|
|
||||||
# A map of user ID -> device ID -> key ID -> key.
|
# A map of user ID -> device ID -> key ID -> key.
|
||||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
|
for result in results:
|
||||||
|
for user_id, device_keys in result.items():
|
||||||
|
for device_id, keys in device_keys.items():
|
||||||
|
for key_id, key in keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
|
||||||
|
|
||||||
|
# Remote failures.
|
||||||
failures: Dict[str, JsonDict] = {}
|
failures: Dict[str, JsonDict] = {}
|
||||||
for user_id, device_keys in results.items():
|
|
||||||
for device_id, keys in device_keys.items():
|
|
||||||
for key_id, json_str in keys.items():
|
|
||||||
json_result.setdefault(user_id, {})[device_id] = {
|
|
||||||
key_id: json_decoder.decode(json_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def claim_client_keys(destination: str) -> None:
|
async def claim_client_keys(destination: str) -> None:
|
||||||
|
@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
|||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||||||
|
|
||||||
async def claim_e2e_one_time_keys(
|
async def claim_e2e_one_time_keys(
|
||||||
self, query_list: Iterable[Tuple[str, str, str]]
|
self, query_list: Iterable[Tuple[str, str, str]]
|
||||||
) -> Dict[str, Dict[str, Dict[str, str]]]:
|
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||||
"""Take a list of one time keys out of the database.
|
"""Take a list of one time keys out of the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
A tuple pf:
|
||||||
|
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||||
|
|
||||||
|
A copy of the input which has not been fulfilled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||||||
key_id, key_json = otk_row
|
key_id, key_json = otk_row
|
||||||
return f"{algorithm}:{key_id}", key_json
|
return f"{algorithm}:{key_id}", key_json
|
||||||
|
|
||||||
results: Dict[str, Dict[str, Dict[str, str]]] = {}
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
|
missing: List[Tuple[str, str, str]] = []
|
||||||
for user_id, device_id, algorithm in query_list:
|
for user_id, device_id, algorithm in query_list:
|
||||||
if self.database_engine.supports_returning:
|
if self.database_engine.supports_returning:
|
||||||
# If we support RETURNING clause we can use a single query that
|
# If we support RETURNING clause we can use a single query that
|
||||||
@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||||||
device_results = results.setdefault(user_id, {}).setdefault(
|
device_results = results.setdefault(user_id, {}).setdefault(
|
||||||
device_id, {}
|
device_id, {}
|
||||||
)
|
)
|
||||||
device_results[claim_row[0]] = claim_row[1]
|
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||||
continue
|
else:
|
||||||
|
missing.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
# No one-time key available, so see if there's a fallback
|
return results, missing
|
||||||
# key
|
|
||||||
|
async def claim_e2e_fallback_keys(
|
||||||
|
self, query_list: Iterable[Tuple[str, str, str]]
|
||||||
|
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||||
|
"""Take a list of fallback keys out of the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||||
|
"""
|
||||||
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
|
for user_id, device_id, algorithm in query_list:
|
||||||
row = await self.db_pool.simple_select_one(
|
row = await self.db_pool.simple_select_one(
|
||||||
table="e2e_fallback_keys_json",
|
table="e2e_fallback_keys_json",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||||||
)
|
)
|
||||||
|
|
||||||
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
|
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
|
||||||
device_results[f"{algorithm}:{key_id}"] = key_json
|
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(self.request_url, URL_LOCATION)
|
self.assertEqual(self.request_url, URL_LOCATION)
|
||||||
self.assertEqual(result, SUCCESS_RESULT_LOCATION)
|
self.assertEqual(result, SUCCESS_RESULT_LOCATION)
|
||||||
|
|
||||||
|
def test_claim_keys(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that the /keys/claim response is properly parsed for missing
|
||||||
|
keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RESPONSE: JsonDict = {
|
||||||
|
"@alice:example.org": {
|
||||||
|
"DEVICE_1": {
|
||||||
|
"signed_curve25519:AAAAHg": {
|
||||||
|
# We don't really care about the content of the keys,
|
||||||
|
# they get passed back transparently.
|
||||||
|
},
|
||||||
|
"signed_curve25519:BBBBHg": {},
|
||||||
|
},
|
||||||
|
"DEVICE_2": {"signed_curve25519:CCCCHg": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def post_json_get_json(
|
||||||
|
uri: str,
|
||||||
|
post_json: Any,
|
||||||
|
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
|
||||||
|
) -> JsonDict:
|
||||||
|
# Ensure the access token is passed as both a header and query arg.
|
||||||
|
if not headers.get("Authorization"):
|
||||||
|
raise RuntimeError("Access token not provided")
|
||||||
|
|
||||||
|
self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
|
||||||
|
return RESPONSE
|
||||||
|
|
||||||
|
# We assign to a method, which mypy doesn't like.
|
||||||
|
self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment]
|
||||||
|
|
||||||
|
MISSING_KEYS = [
|
||||||
|
# Known user, known device, missing algorithm.
|
||||||
|
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
|
||||||
|
# Known user, missing device.
|
||||||
|
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
|
||||||
|
# Unknown user.
|
||||||
|
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
|
||||||
|
]
|
||||||
|
|
||||||
|
claimed_keys, missing = self.get_success(
|
||||||
|
self.api.claim_client_keys(
|
||||||
|
self.service,
|
||||||
|
[
|
||||||
|
# Found devices
|
||||||
|
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
|
||||||
|
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
|
||||||
|
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
|
||||||
|
]
|
||||||
|
+ MISSING_KEYS,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(claimed_keys, RESPONSE)
|
||||||
|
self.assertEqual(missing, MISSING_KEYS)
|
||||||
|
@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||||||
|
|
||||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.device import DeviceHandler
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(federation_client=mock.Mock())
|
self.appservice_api = mock.Mock()
|
||||||
|
return self.setup_test_homeserver(
|
||||||
|
federation_client=mock.Mock(), application_service_api=self.appservice_api
|
||||||
|
)
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_e2e_keys_handler()
|
self.handler = hs.get_e2e_keys_handler()
|
||||||
@ -941,3 +947,71 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# The two requests to the local homeserver should be identical.
|
# The two requests to the local homeserver should be identical.
|
||||||
self.assertEqual(response_1, response_2)
|
self.assertEqual(response_1, response_2)
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
|
||||||
|
def test_query_appservice(self) -> None:
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id_1 = "xyz"
|
||||||
|
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||||
|
device_id_2 = "abc"
|
||||||
|
otk = {"alg1:k2": "key2"}
|
||||||
|
|
||||||
|
# Inject an appservice interested in this user.
|
||||||
|
appservice = ApplicationService(
|
||||||
|
token="i_am_an_app_service",
|
||||||
|
id="1234",
|
||||||
|
namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
|
||||||
|
# Note: this user does not have to match the regex above
|
||||||
|
sender="@as_main:test",
|
||||||
|
)
|
||||||
|
self.hs.get_datastores().main.services_cache = [appservice]
|
||||||
|
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
|
||||||
|
[appservice]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup a response, but only for device 2.
|
||||||
|
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||||
|
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
|
||||||
|
)
|
||||||
|
|
||||||
|
# we shouldn't have any unused fallback keys yet
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(res, [])
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user,
|
||||||
|
device_id_1,
|
||||||
|
{"fallback_keys": fallback_key},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# we should now have an unused alg1 key
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# claiming an OTK when no OTKs are available should ask the appservice, then
|
||||||
|
# query the fallback keys.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
timeout=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {device_id_1: fallback_key, device_id_2: otk}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user