mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
562 lines
20 KiB
Python
562 lines
20 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
# Copyright 2015, 2016 OpenMarket Ltd
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
import logging
|
|
import urllib.parse
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from prometheus_client import Counter
|
|
from typing_extensions import ParamSpec, TypeGuard
|
|
|
|
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
|
|
from synapse.api.errors import CodeMessageException, HttpResponseException
|
|
from synapse.appservice import (
|
|
ApplicationService,
|
|
TransactionOneTimeKeysCount,
|
|
TransactionUnusedFallbackKeys,
|
|
)
|
|
from synapse.events import EventBase
|
|
from synapse.events.utils import SerializeEventConfig, serialize_event
|
|
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
|
|
from synapse.logging import opentracing
|
|
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
|
|
from synapse.util.caches.response_cache import ResponseCache
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Prometheus counters for traffic pushed to application services. Every
# counter carries a single "service" label; push_bulk() fills it with the
# application service's ID.

# Incremented once per successful push_bulk() transaction.
sent_transactions_counter = Counter(
    "synapse_appservice_api_sent_transactions",
    "Number of /transactions/ requests sent",
    ["service"],
)

# Incremented once per push_bulk() transaction that did not succeed.
failed_transactions_counter = Counter(
    "synapse_appservice_api_failed_transactions",
    "Number of /transactions/ requests that failed to send",
    ["service"],
)

# Incremented by the number of serialized events in a successful transaction.
sent_events_counter = Counter(
    "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)

# Incremented by the number of ephemeral events handed to a successful
# push_bulk() call.
sent_ephemeral_counter = Counter(
    "synapse_appservice_api_sent_ephemeral",
    "Number of ephemeral events sent to the AS",
    ["service"],
)

# Incremented by the number of to-device messages handed to a successful
# push_bulk() call.
sent_todevice_counter = Counter(
    "synapse_appservice_api_sent_todevice",
    "Number of todevice messages sent to the AS",
    ["service"],
)

# One hour, in milliseconds; used as the timeout of the 3PE protocol
# metadata cache.
HOUR_IN_MS = 60 * 60 * 1000

# Path prefix appended to an application service's base URL when building
# request URLs.
APP_SERVICE_PREFIX = "/_matrix/app/v1"

# Generic typing helpers.
# NOTE(review): neither P nor R is referenced in the visible part of this
# module -- confirm whether they are still needed.
P = ParamSpec("P")
R = TypeVar("R")
|
|
|
|
|
|
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
|
|
if "instances" not in info:
|
|
return False
|
|
if not isinstance(info["instances"], list):
|
|
return False
|
|
return True
|
|
|
|
|
|
def _is_valid_3pe_result(r: object, field: str) -> TypeGuard[JsonDict]:
|
|
if not isinstance(r, dict):
|
|
return False
|
|
|
|
for k in (field, "protocol"):
|
|
if k not in r:
|
|
return False
|
|
if not isinstance(r[k], str):
|
|
return False
|
|
|
|
if "fields" not in r:
|
|
return False
|
|
fields = r["fields"]
|
|
if not isinstance(fields, dict):
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
class ApplicationServiceApi(SimpleHttpClient):
    """This class manages HS -> AS communications, including querying and
    pushing.
    """

    def __init__(self, hs: "HomeServer"):
        """Set up the client and the cache used for 3PE protocol metadata.

        Args:
            hs: The homeserver; supplies the clock and the appservice config.
        """
        super().__init__(hs)
        self.clock = hs.get_clock()
        self.config = hs.config.appservice

        # Keyed by (service ID, protocol); see get_3pe_protocol().
        self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
        )

    def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
        """This makes sure we have always the auth header and opentracing headers set.

        Args:
            service: The application service the request is destined for.

        Returns:
            Headers containing a bearer `Authorization` entry built from the
            service's `hs_token`, passed through the opentracing header
            injection.
        """

        # This is also ensured before in the functions. However this is needed to please
        # the typechecks.
        assert service.hs_token is not None

        headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]}
        opentracing.inject_header_dict(headers, check_destination=False)
        return headers

    async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
        """Ask an application service about a user ID.

        Args:
            service: The application service to query.
            user_id: The user ID to ask about.

        Returns:
            True if the request succeeded with a non-None response. False if
            the service has no URL, responded with an error (a 404 is not
            logged, other codes are) or the request failed in any other way.
        """
        if service.url is None:
            return False

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
                args,
                headers=self._get_headers(service),
            )
            if response is not None:  # just an empty json object
                return True
        except CodeMessageException as e:
            if e.code == 404:
                return False
            logger.warning("query_user to %s received %s", service.url, e.code)
        except Exception as ex:
            logger.warning("query_user to %s threw exception %s", service.url, ex)
        return False

    async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
        """Ask an application service about a room alias.

        Args:
            service: The application service to query.
            alias: The room alias to ask about.

        Returns:
            True if the request succeeded with a non-None response. False if
            the service has no URL, responded with an error (a 404 is not
            logged, other codes are) or the request failed in any other way.
        """
        if service.url is None:
            return False

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
                args,
                headers=self._get_headers(service),
            )
            if response is not None:  # just an empty json object
                return True
        except CodeMessageException as e:
            # A 404 is an ordinary negative answer: handle it quietly, the
            # same way query_user() does, and only warn on other codes.
            if e.code == 404:
                return False
            logger.warning("query_alias to %s received %s", service.url, e.code)
        except Exception as ex:
            logger.warning("query_alias to %s threw exception %s", service.url, ex)
        return False

    async def query_3pe(
        self,
        service: "ApplicationService",
        kind: str,
        protocol: str,
        fields: Dict[bytes, List[bytes]],
    ) -> List[JsonDict]:
        """Query an application service for third-party users or locations.

        Args:
            service: The application service to query.
            kind: Either ThirdPartyEntityKind.USER or
                ThirdPartyEntityKind.LOCATION.
            protocol: The protocol name, used as the last path segment.
            fields: Query parameters to send with the request.

        Returns:
            The entries of the response that pass validation. An empty list
            if the service has no URL, the response is not a list, or the
            request failed.

        Raises:
            ValueError: If `kind` is not one of the recognised kinds.
        """
        if kind == ThirdPartyEntityKind.USER:
            required_field = "userid"
        elif kind == ThirdPartyEntityKind.LOCATION:
            required_field = "alias"
        else:
            # Exceptions do not apply %-formatting to extra arguments the way
            # logging calls do, so build the message explicitly.
            raise ValueError(
                "Unrecognised 'kind' argument %r to query_3pe()" % (kind,)
            )
        if service.url is None:
            return []

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args: Mapping[bytes, Union[List[bytes], str]] = fields
            if self.config.use_appservice_legacy_authorization:
                args = {
                    **fields,
                    b"access_token": service.hs_token,
                }

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
                args=args,
                headers=self._get_headers(service),
            )
            if not isinstance(response, list):
                logger.warning(
                    "query_3pe to %s returned an invalid response %r",
                    service.url,
                    response,
                )
                return []

            ret = []
            for r in response:
                if _is_valid_3pe_result(r, field=required_field):
                    ret.append(r)
                else:
                    logger.warning(
                        "query_3pe to %s returned an invalid result %r", service.url, r
                    )

            return ret
        except Exception as ex:
            logger.warning("query_3pe to %s threw exception %s", service.url, ex)
            return []

    async def get_3pe_protocol(
        self, service: "ApplicationService", protocol: str
    ) -> Optional[JsonDict]:
        """Fetch third-party protocol metadata from an application service.

        The lookup goes through `self.protocol_meta_cache`, keyed by
        (service ID, protocol).

        Args:
            service: The application service to query.
            protocol: The protocol name to fetch metadata for.

        Returns:
            An empty dict if the service has no URL. Otherwise the metadata
            returned by the service, where every instance that has a
            "network_id" also gets an "instance_id" entry, or None if the
            response was invalid or the request failed.
        """
        if service.url is None:
            return {}

        async def _get() -> Optional[JsonDict]:
            # This is required by the configuration.
            assert service.hs_token is not None
            try:
                args = None
                if self.config.use_appservice_legacy_authorization:
                    args = {"access_token": service.hs_token}

                info = await self.get_json(
                    f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
                    args,
                    headers=self._get_headers(service),
                )

                if not _is_valid_3pe_metadata(info):
                    logger.warning(
                        "query_3pe_protocol to %s did not return a valid result",
                        service.url,
                    )
                    return None

                for instance in info.get("instances", []):
                    network_id = instance.get("network_id", None)
                    if network_id is not None:
                        instance["instance_id"] = ThirdPartyInstanceID(
                            service.id, network_id
                        ).to_string()

                return info
            except Exception as ex:
                logger.warning(
                    "query_3pe_protocol to %s threw exception %s", service.url, ex
                )
                return None

        key = (service.id, protocol)
        return await self.protocol_meta_cache.wrap(key, _get)

    async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None:
        """Send a ping request to an application service.

        Errors from the underlying request are not caught here.

        Args:
            service: The application service to ping; its URL must be set.
            txn_id: Sent to the service as "transaction_id" in the body.
        """
        # The caller should check that url is set
        assert service.url is not None, "ping called without URL being set"

        # This is required by the configuration.
        assert service.hs_token is not None

        await self.post_json_get_json(
            uri=f"{service.url}{APP_SERVICE_PREFIX}/ping",
            post_json={"transaction_id": txn_id},
            headers=self._get_headers(service),
        )

    async def push_bulk(
        self,
        service: "ApplicationService",
        events: Sequence[EventBase],
        ephemeral: List[JsonMapping],
        to_device_messages: List[JsonMapping],
        one_time_keys_count: TransactionOneTimeKeysCount,
        unused_fallback_keys: TransactionUnusedFallbackKeys,
        device_list_summary: DeviceListUpdates,
        txn_id: Optional[int] = None,
    ) -> bool:
        """
        Push data to an application service.

        Args:
            service: The application service to send to.
            events: The persistent events to send.
            ephemeral: The ephemeral events to send.
            to_device_messages: The to-device messages to send.
            one_time_keys_count: One-time key counts; included only when the
                service has MSC3202 transaction extensions enabled and the
                value is non-empty.
            unused_fallback_keys: Unused fallback key types; included under
                the same conditions as `one_time_keys_count`.
            device_list_summary: Device list changes; included under the same
                conditions as `one_time_keys_count`.
            txn_id: An unique ID to assign to this transaction. Application services should
                deduplicate transactions received with identical IDs. If not
                given, a warning is logged and 0 is used.

        Returns:
            True if the task succeeded (or the service has no URL, in which
            case nothing is sent), False if it failed.
        """
        if service.url is None:
            return True

        # This is required by the configuration.
        assert service.hs_token is not None

        serialized_events = self._serialize(service, events)

        if txn_id is None:
            logger.warning(
                "push_bulk: Missing txn ID sending events to %s", service.url
            )
            txn_id = 0

        # Never send ephemeral events to appservices that do not support it
        body: JsonDict = {"events": serialized_events}
        if service.supports_ephemeral:
            body.update(
                {
                    # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
                    "de.sorunome.msc2409.ephemeral": ephemeral,
                    "de.sorunome.msc2409.to_device": to_device_messages,
                }
            )

        # TODO: Update to stable prefixes once MSC3202 completes FCP merge
        if service.msc3202_transaction_extensions:
            if one_time_keys_count:
                # The counts are deliberately written under both spellings of
                # the key.
                body[
                    "org.matrix.msc3202.device_one_time_key_counts"
                ] = one_time_keys_count
                body[
                    "org.matrix.msc3202.device_one_time_keys_count"
                ] = one_time_keys_count
            if unused_fallback_keys:
                body[
                    "org.matrix.msc3202.device_unused_fallback_key_types"
                ] = unused_fallback_keys
            if device_list_summary:
                body["org.matrix.msc3202.device_lists"] = {
                    "changed": list(device_list_summary.changed),
                    "left": list(device_list_summary.left),
                }

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            await self.put_json(
                f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
                json_body=body,
                args=args,
                headers=self._get_headers(service),
            )
            # Only build the list of event IDs when it will actually be logged.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "push_bulk to %s succeeded! events=%s",
                    service.url,
                    [event.get("event_id") for event in events],
                )
            sent_transactions_counter.labels(service.id).inc()
            sent_events_counter.labels(service.id).inc(len(serialized_events))
            sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
            sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
            return True
        except CodeMessageException as e:
            logger.warning(
                "push_bulk to %s received code=%s msg=%s",
                service.url,
                e.code,
                e.msg,
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
        except Exception as ex:
            logger.warning(
                "push_bulk to %s threw exception(%s) %s args=%s",
                service.url,
                type(ex).__name__,
                ex,
                ex.args,
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
        # Both failure paths fall through to here.
        failed_transactions_counter.labels(service.id).inc()
        return False

    async def claim_client_keys(
        self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
    ) -> Tuple[
        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
    ]:
        """Claim one time keys from an application service.

        Note that any error (including a timeout) is treated as the application
        service having no information.

        Args:
            service: The application service to query.
            query: An iterable of tuples of (user ID, device ID, algorithm, count).

        Returns:
            A tuple of:
                A map of user ID -> a map device ID -> a map of key ID -> JSON dict.

                A copy of the input which has not been fulfilled because the
                appservice doesn't support this endpoint or has not returned
                data for that tuple.
        """
        if service.url is None:
            return {}, query

        # This is required by the configuration.
        assert service.hs_token is not None

        # Create the expected payload shape.
        body: Dict[str, Dict[str, List[str]]] = {}
        for user_id, device, algorithm, count in query:
            body.setdefault(user_id, {}).setdefault(device, []).extend(
                [algorithm] * count
            )

        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
        try:
            response = await self.post_json_get_json(
                uri,
                body,
                headers=self._get_headers(service),
            )
        except HttpResponseException as e:
            # The appservice doesn't support this endpoint.
            if is_unknown_endpoint(e):
                return {}, query
            logger.warning("claim_keys to %s received %s", uri, e.code)
            return {}, query
        except Exception as ex:
            logger.warning("claim_keys to %s threw exception %s", uri, ex)
            return {}, query

        # Check if the appservice fulfilled all of the queried user/device/algorithms
        # or if some are still missing.
        #
        # TODO This still places a lot of faith in the response shape being
        # correct; only shapes that make the walk below raise are rejected.
        missing = []
        try:
            for user_id, device, algorithm, count in query:
                # Count the number of keys in the response for this algorithm by
                # checking which key IDs start with the algorithm. This uses that
                # True == 1 in Python to generate a count.
                response_count = sum(
                    key_id.startswith(f"{algorithm}:")
                    for key_id in response.get(user_id, {}).get(device, {})
                )
                count -= response_count
                # If the appservice responds with fewer keys than requested, then
                # consider the request unfulfilled.
                if count > 0:
                    missing.append((user_id, device, algorithm, count))
        except (AttributeError, TypeError) as ex:
            # The response is not shaped like the nested mapping we expect.
            # Treat it like any other error: the appservice has no information.
            logger.warning(
                "claim_keys to %s returned an invalid response: %s", uri, ex
            )
            return {}, query

        return response, missing

    async def query_keys(
        self, service: "ApplicationService", query: Dict[str, List[str]]
    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
        """Query the application service for keys.

        Note that any error (including a timeout) is treated as the application
        service having no information.

        Args:
            service: The application service to query.
            query: The request body, sent to the application service as-is.
                NOTE(review): presumably a map of user ID -> list of device
                IDs -- confirm against the caller.

        Returns:
            A map of device_keys/master_keys/self_signing_keys/user_signing_keys:

            device_keys is a map of user ID -> a map device ID -> device info.
        """
        if service.url is None:
            return {}

        # This is required by the configuration.
        assert service.hs_token is not None

        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
        try:
            response = await self.post_json_get_json(
                uri,
                query,
                headers=self._get_headers(service),
            )
        except HttpResponseException as e:
            # The appservice doesn't support this endpoint.
            if is_unknown_endpoint(e):
                return {}
            logger.warning("query_keys to %s received %s", uri, e.code)
            return {}
        except Exception as ex:
            logger.warning("query_keys to %s threw exception %s", uri, ex)
            return {}

        return response

    def _serialize(
        self, service: "ApplicationService", events: Iterable[EventBase]
    ) -> List[JsonDict]:
        """Serialize events for sending to an application service.

        Args:
            service: The application service the events are destined for.
            events: The events to serialize.

        Returns:
            The serialized events, in the same order as `events`.
        """
        time_now = self.clock.time_msec()
        return [
            serialize_event(
                e,
                time_now,
                config=SerializeEventConfig(
                    as_client_event=True,
                    # If this is an invite or a knock membership event, and we're interested
                    # in this user, then include any stripped state alongside the event.
                    include_stripped_room_state=(
                        e.type == EventTypes.Member
                        and (
                            e.membership == Membership.INVITE
                            or e.membership == Membership.KNOCK
                        )
                        and service.is_interested_in_user(e.state_key)
                    ),
                ),
            )
            for e in events
        ]
|