forked-synapse/synapse/appservice/api.py

555 lines
20 KiB
Python

# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Counter
from typing_extensions import ParamSpec, TypeGuard
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException, HttpResponseException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.logging import opentracing
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
from synapse.server import HomeServer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Prometheus counters. Each one carries a single "service" label, which
# `ApplicationServiceApi.push_bulk` fills in with the application service's ID.
sent_transactions_counter = Counter(
"synapse_appservice_api_sent_transactions",
"Number of /transactions/ requests sent",
["service"],
)
failed_transactions_counter = Counter(
"synapse_appservice_api_failed_transactions",
"Number of /transactions/ requests that failed to send",
["service"],
)
sent_events_counter = Counter(
"synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)
sent_ephemeral_counter = Counter(
"synapse_appservice_api_sent_ephemeral",
"Number of ephemeral events sent to the AS",
["service"],
)
sent_todevice_counter = Counter(
"synapse_appservice_api_sent_todevice",
"Number of todevice messages sent to the AS",
["service"],
)

# One hour in milliseconds; used as the timeout of the protocol metadata cache.
HOUR_IN_MS = 60 * 60 * 1000

# Path prefix placed between the service's base URL and the endpoint name when
# building most request URLs in this module.
APP_SERVICE_PREFIX = "/_matrix/app/v1"

# Generic typing helpers. NOTE(review): neither is referenced in the visible
# part of this module -- confirm whether they are still needed before removing.
P = ParamSpec("P")
R = TypeVar("R")
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True
def _is_valid_3pe_result(r: object, field: str) -> TypeGuard[JsonDict]:
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
return True
class ApplicationServiceApi(SimpleHttpClient):
    """This class manages HS -> AS communications, including querying and
    pushing.

    Every request carries the service's `hs_token` in an
    `Authorization: Bearer` header. When `use_appservice_legacy_authorization`
    is enabled in the appservice config, `query_user`, `query_alias`,
    `query_3pe`, `get_3pe_protocol` and `push_bulk` additionally send the token
    as an `access_token` query parameter.
    """

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.clock = hs.get_clock()
        self.config = hs.config.appservice

        # Cache for `get_3pe_protocol`, keyed on (service ID, protocol) and
        # configured with a one hour timeout.
        self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
        )

    def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
        """Build the headers sent with every request to an application service.

        This makes sure we always have the auth header and opentracing headers
        set.

        Args:
            service: The application service the request is for. Its
                `hs_token` must be set.

        Returns:
            A header dict holding the `Authorization` header, after it has been
            passed through `opentracing.inject_header_dict`.
        """
        # The callers assert this too. It is repeated here to please the type
        # checker.
        assert service.hs_token is not None

        headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]}
        opentracing.inject_header_dict(headers, check_destination=False)
        return headers

    async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
        """Ask an application service whether it knows about a user.

        Args:
            service: The application service to query.
            user_id: The user ID to look up.

        Returns:
            True if the service answered with a (non-None) JSON body. False if
            the service has no URL, answered with an error status, or the
            request failed. A 404 is not logged; other failures are logged at
            warning level.
        """
        if service.url is None:
            return False

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
                args,
                headers=self._get_headers(service),
            )
            if response is not None:  # just an empty json object
                return True
        except CodeMessageException as e:
            if e.code == 404:
                return False
            logger.warning("query_user to %s received %s", service.url, e.code)
        except Exception as ex:
            logger.warning("query_user to %s threw exception %s", service.url, ex)
        return False

    async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
        """Ask an application service whether it knows about a room alias.

        Args:
            service: The application service to query.
            alias: The room alias to look up.

        Returns:
            True if the service answered with a (non-None) JSON body. False if
            the service has no URL, answered with an error status, or the
            request failed. A 404 is not logged; other failures are logged at
            warning level.
        """
        if service.url is None:
            return False

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
                args,
                headers=self._get_headers(service),
            )
            if response is not None:  # just an empty json object
                return True
        except CodeMessageException as e:
            # A 404 is the ordinary "no such alias" answer, so return quietly
            # (as `query_user` does) and only warn about other status codes.
            if e.code == 404:
                return False
            logger.warning("query_alias to %s received %s", service.url, e.code)
        except Exception as ex:
            logger.warning("query_alias to %s threw exception %s", service.url, ex)
        return False

    async def query_3pe(
        self,
        service: "ApplicationService",
        kind: str,
        protocol: str,
        fields: Dict[bytes, List[bytes]],
    ) -> List[JsonDict]:
        """Query an application service for third-party users or locations.

        Args:
            service: The application service to query.
            kind: Either `ThirdPartyEntityKind.USER` or
                `ThirdPartyEntityKind.LOCATION`.
            protocol: The third-party protocol name; it is URL-quoted into the
                request path.
            fields: Query parameters to forward to the application service.

        Returns:
            The entries of the response that pass `_is_valid_3pe_result`.
            Invalid entries are logged and dropped. An empty list is returned
            if the service has no URL, the response is not a list, or the
            request failed.

        Raises:
            ValueError: If `kind` is not one of the two recognised kinds.
        """
        if kind == ThirdPartyEntityKind.USER:
            required_field = "userid"
        elif kind == ThirdPartyEntityKind.LOCATION:
            required_field = "alias"
        else:
            # Interpolate explicitly: unlike the logging functions, exception
            # constructors do not apply %-formatting to their arguments.
            raise ValueError(
                "Unrecognised 'kind' argument %r to query_3pe()" % (kind,)
            )

        if service.url is None:
            return []

        # This is required by the configuration.
        assert service.hs_token is not None

        try:
            args: Mapping[bytes, Union[List[bytes], str]] = fields
            if self.config.use_appservice_legacy_authorization:
                args = {
                    **fields,
                    b"access_token": service.hs_token,
                }

            response = await self.get_json(
                f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
                args=args,
                headers=self._get_headers(service),
            )
            if not isinstance(response, list):
                logger.warning(
                    "query_3pe to %s returned an invalid response %r",
                    service.url,
                    response,
                )
                return []

            ret = []
            for r in response:
                if _is_valid_3pe_result(r, field=required_field):
                    ret.append(r)
                else:
                    logger.warning(
                        "query_3pe to %s returned an invalid result %r", service.url, r
                    )

            return ret
        except Exception as ex:
            logger.warning("query_3pe to %s threw exception %s", service.url, ex)
            return []

    async def get_3pe_protocol(
        self, service: "ApplicationService", protocol: str
    ) -> Optional[JsonDict]:
        """Fetch the metadata an application service publishes for a protocol.

        Lookups go through `self.protocol_meta_cache`, keyed on
        (service ID, protocol).

        Args:
            service: The application service to query.
            protocol: The third-party protocol name; it is URL-quoted into the
                request path.

        Returns:
            An empty dict if the service has no URL. None if the request failed
            or the response did not pass `_is_valid_3pe_metadata`. Otherwise
            the response, in which every instance that has a "network_id" has
            been given an "instance_id" derived from the service ID and that
            network ID.
        """
        if service.url is None:
            return {}

        async def _get() -> Optional[JsonDict]:
            # This is required by the configuration.
            assert service.hs_token is not None
            try:
                args = None
                if self.config.use_appservice_legacy_authorization:
                    args = {"access_token": service.hs_token}

                info = await self.get_json(
                    f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
                    args,
                    headers=self._get_headers(service),
                )

                if not _is_valid_3pe_metadata(info):
                    logger.warning(
                        "query_3pe_protocol to %s did not return a valid result",
                        service.url,
                    )
                    return None

                for instance in info.get("instances", []):
                    network_id = instance.get("network_id", None)
                    if network_id is not None:
                        instance["instance_id"] = ThirdPartyInstanceID(
                            service.id, network_id
                        ).to_string()

                return info
            except Exception as ex:
                logger.warning(
                    "query_3pe_protocol to %s threw exception %s", service.url, ex
                )
                return None

        key = (service.id, protocol)
        return await self.protocol_meta_cache.wrap(key, _get)

    async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None:
        """Send a ping request to an application service.

        No exceptions are caught here: whatever `post_json_get_json` raises
        propagates to the caller.

        Args:
            service: The application service to ping. The caller must have
                checked that its URL is set.
            txn_id: Sent to the service as "transaction_id" in the request body.
        """
        # The caller should check that url is set
        assert service.url is not None, "ping called without URL being set"

        # This is required by the configuration.
        assert service.hs_token is not None

        await self.post_json_get_json(
            uri=f"{service.url}{APP_SERVICE_PREFIX}/ping",
            post_json={"transaction_id": txn_id},
            headers=self._get_headers(service),
        )

    async def push_bulk(
        self,
        service: "ApplicationService",
        events: Sequence[EventBase],
        ephemeral: List[JsonDict],
        to_device_messages: List[JsonDict],
        one_time_keys_count: TransactionOneTimeKeysCount,
        unused_fallback_keys: TransactionUnusedFallbackKeys,
        device_list_summary: DeviceListUpdates,
        txn_id: Optional[int] = None,
    ) -> bool:
        """
        Push data to an application service.

        Args:
            service: The application service to send to.
            events: The persistent events to send.
            ephemeral: The ephemeral events to send. Only included if the
                service supports ephemeral events.
            to_device_messages: The to-device messages to send. Only included
                if the service supports ephemeral events.
            one_time_keys_count: One-time key counts to send. Only included if
                non-empty and the service has MSC3202 transaction extensions
                enabled.
            unused_fallback_keys: Unused fallback key types to send. Only
                included if non-empty and the service has MSC3202 transaction
                extensions enabled.
            device_list_summary: Changed/left device list users to send. Only
                included if non-empty and the service has MSC3202 transaction
                extensions enabled.
            txn_id: A unique ID to assign to this transaction. Application
                services should deduplicate transactions received with
                identical IDs. If None, a warning is logged and 0 is used.

        Returns:
            True if the task succeeded (or the service has no URL to send to),
            False if it failed.
        """
        if service.url is None:
            return True

        # This is required by the configuration.
        assert service.hs_token is not None

        serialized_events = self._serialize(service, events)

        if txn_id is None:
            logger.warning(
                "push_bulk: Missing txn ID sending events to %s", service.url
            )
            txn_id = 0

        # Never send ephemeral events to appservices that do not support it
        body: JsonDict = {"events": serialized_events}
        if service.supports_ephemeral:
            body.update(
                {
                    # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
                    "de.sorunome.msc2409.ephemeral": ephemeral,
                    "de.sorunome.msc2409.to_device": to_device_messages,
                }
            )

        # TODO: Update to stable prefixes once MSC3202 completes FCP merge
        if service.msc3202_transaction_extensions:
            if one_time_keys_count:
                # NOTE(review): the same value is sent under two spellings of
                # the key ("key_counts" and "keys_count"), presumably for
                # compatibility -- confirm before removing either.
                body[
                    "org.matrix.msc3202.device_one_time_key_counts"
                ] = one_time_keys_count
                body[
                    "org.matrix.msc3202.device_one_time_keys_count"
                ] = one_time_keys_count
            if unused_fallback_keys:
                body[
                    "org.matrix.msc3202.device_unused_fallback_key_types"
                ] = unused_fallback_keys
            if device_list_summary:
                body["org.matrix.msc3202.device_lists"] = {
                    "changed": list(device_list_summary.changed),
                    "left": list(device_list_summary.left),
                }

        try:
            args = None
            if self.config.use_appservice_legacy_authorization:
                args = {"access_token": service.hs_token}

            await self.put_json(
                f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
                json_body=body,
                args=args,
                headers=self._get_headers(service),
            )
            # Only build the list of event IDs if it will actually be logged.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "push_bulk to %s succeeded! events=%s",
                    service.url,
                    [event.get("event_id") for event in events],
                )
            sent_transactions_counter.labels(service.id).inc()
            sent_events_counter.labels(service.id).inc(len(serialized_events))
            sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
            sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
            return True
        except CodeMessageException as e:
            # The traceback is only attached when debug logging is enabled.
            logger.warning(
                "push_bulk to %s received code=%s msg=%s",
                service.url,
                e.code,
                e.msg,
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
        except Exception as ex:
            logger.warning(
                "push_bulk to %s threw exception(%s) %s args=%s",
                service.url,
                type(ex).__name__,
                ex,
                ex.args,
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )
        # Both failure paths fall through to here.
        failed_transactions_counter.labels(service.id).inc()
        return False

    async def claim_client_keys(
        self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
    ) -> Tuple[
        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
    ]:
        """Claim one time keys from an application service.

        Note that any error (including a timeout or a response which is not a
        JSON object) is treated as the application service having no
        information.

        Args:
            service: The application service to query.
            query: A list of tuples of (user ID, device ID, algorithm, number
                of keys wanted).

        Returns:
            A tuple of:
                A map of user ID -> a map device ID -> a map of key ID -> JSON dict.

                The part of the input which has not been fulfilled because the
                appservice doesn't support this endpoint or has not returned
                (enough) data for that tuple. Each count is reduced by the
                number of keys that were returned.
        """
        if service.url is None:
            return {}, query

        # This is required by the configuration.
        assert service.hs_token is not None

        # Create the expected payload shape: the algorithm is repeated once per
        # requested key.
        body: Dict[str, Dict[str, List[str]]] = {}
        for user_id, device, algorithm, count in query:
            body.setdefault(user_id, {}).setdefault(device, []).extend(
                [algorithm] * count
            )

        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
        try:
            response = await self.post_json_get_json(
                uri,
                body,
                headers=self._get_headers(service),
            )
        except HttpResponseException as e:
            # The appservice doesn't support this endpoint.
            if is_unknown_endpoint(e):
                return {}, query
            logger.warning("claim_keys to %s received %s", uri, e.code)
            return {}, query
        except Exception as ex:
            logger.warning("claim_keys to %s threw exception %s", uri, ex)
            return {}, query

        # A response which is not a JSON object cannot contain any keys; treat
        # it like any other failure rather than raising below.
        if not isinstance(response, dict):
            logger.warning(
                "claim_keys to %s returned an invalid response %r", uri, response
            )
            return {}, query

        # Check if the appservice fulfilled all of the queried user/device/algorithms
        # or if some are still missing.
        #
        # TODO The contents of the individual key objects are not validated,
        # and the response is returned to the caller unmodified.
        missing: List[Tuple[str, str, str, int]] = []
        for user_id, device, algorithm, count in query:
            user_keys = response.get(user_id)
            device_keys = user_keys.get(device) if isinstance(user_keys, dict) else None

            # Count the number of keys in the response for this algorithm by
            # checking which key IDs start with the algorithm. This uses that
            # True == 1 in Python to generate a count. Entries of the wrong
            # shape count as no keys.
            if isinstance(device_keys, dict):
                prefix = f"{algorithm}:"
                response_count = sum(key_id.startswith(prefix) for key_id in device_keys)
            else:
                response_count = 0
            count -= response_count

            # If the appservice responds with fewer keys than requested, then
            # consider the request unfulfilled.
            if count > 0:
                missing.append((user_id, device, algorithm, count))

        return response, missing

    async def query_keys(
        self, service: "ApplicationService", query: Dict[str, List[str]]
    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
        """Query the application service for keys.

        Note that any error (including a timeout or a response which is not a
        JSON object) is treated as the application service having no
        information.

        Args:
            service: The application service to query.
            query: The request body, sent to the application service as-is: a
                map of string -> list of strings (presumably user ID -> device
                IDs; verify against the caller).

        Returns:
            A map of device_keys/master_keys/self_signing_keys/user_signing_keys:

            device_keys is a map of user ID -> a map device ID -> device info.
        """
        if service.url is None:
            return {}

        # This is required by the configuration.
        assert service.hs_token is not None

        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
        try:
            response = await self.post_json_get_json(
                uri,
                query,
                headers=self._get_headers(service),
            )
        except HttpResponseException as e:
            # The appservice doesn't support this endpoint.
            if is_unknown_endpoint(e):
                return {}
            logger.warning("query_keys to %s received %s", uri, e.code)
            return {}
        except Exception as ex:
            logger.warning("query_keys to %s threw exception %s", uri, ex)
            return {}

        # Honour the declared return type: anything other than a JSON object
        # carries no usable information.
        if not isinstance(response, dict):
            logger.warning(
                "query_keys to %s returned an invalid response %r", uri, response
            )
            return {}

        return response

    def _serialize(
        self, service: "ApplicationService", events: Iterable[EventBase]
    ) -> List[JsonDict]:
        """Serialize events for inclusion in a transaction body.

        Args:
            service: The application service the events are destined for; used
                to decide whether stripped room state is included.
            events: The events to serialize.

        Returns:
            One serialized event per input event, in the same order.
        """
        # A single timestamp is shared by all events in the batch.
        time_now = self.clock.time_msec()
        return [
            serialize_event(
                e,
                time_now,
                config=SerializeEventConfig(
                    as_client_event=True,
                    # If this is an invite or a knock membership event, and we're interested
                    # in this user, then include any stripped state alongside the event.
                    include_stripped_room_state=(
                        e.type == EventTypes.Member
                        and (
                            e.membership == Membership.INVITE
                            or e.membership == Membership.KNOCK
                        )
                        and service.is_interested_in_user(e.state_key)
                    ),
                ),
            )
            for e in events
        ]