mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
d8cc86eff4
Remove type hints from comments which have been added as Python type hints. This helps avoid drift between comments and reality, as well as removing redundant information. Also adds some missing type hints which were simple to fill in.
396 lines
14 KiB
Python
396 lines
14 KiB
Python
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import functools
|
|
import logging
|
|
import re
|
|
import time
|
|
from http import HTTPStatus
|
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
|
|
|
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
|
from synapse.api.urls import FEDERATION_V1_PREFIX
|
|
from synapse.http.server import HttpServer, ServletCallback
|
|
from synapse.http.servlet import parse_json_object_from_request
|
|
from synapse.http.site import SynapseRequest
|
|
from synapse.logging.context import run_in_background
|
|
from synapse.logging.opentracing import (
|
|
active_span,
|
|
set_tag,
|
|
span_context_from_request,
|
|
start_active_span,
|
|
start_active_span_follows_from,
|
|
whitelisted_homeserver,
|
|
)
|
|
from synapse.types import JsonDict
|
|
from synapse.util.cancellation import is_function_cancellable
|
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
from synapse.util.stringutils import parse_and_validate_server_name
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AuthenticationError(SynapseError):
|
|
"""There was a problem authenticating the request"""
|
|
|
|
|
|
class NoAuthenticationError(AuthenticationError):
|
|
"""The request had no authentication information"""
|
|
|
|
|
|
class Authenticator:
|
|
def __init__(self, hs: "HomeServer"):
|
|
self._clock = hs.get_clock()
|
|
self.keyring = hs.get_keyring()
|
|
self.server_name = hs.hostname
|
|
self.store = hs.get_datastores().main
|
|
self.federation_domain_whitelist = (
|
|
hs.config.federation.federation_domain_whitelist
|
|
)
|
|
self.notifier = hs.get_notifier()
|
|
|
|
self.replication_client = None
|
|
if hs.config.worker.worker_app:
|
|
self.replication_client = hs.get_replication_command_handler()
|
|
|
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
|
async def authenticate_request(
|
|
self, request: SynapseRequest, content: Optional[JsonDict]
|
|
) -> str:
|
|
now = self._clock.time_msec()
|
|
json_request: JsonDict = {
|
|
"method": request.method.decode("ascii"),
|
|
"uri": request.uri.decode("ascii"),
|
|
"destination": self.server_name,
|
|
"signatures": {},
|
|
}
|
|
|
|
if content is not None:
|
|
json_request["content"] = content
|
|
|
|
origin = None
|
|
|
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
|
|
|
if not auth_headers:
|
|
raise NoAuthenticationError(
|
|
HTTPStatus.UNAUTHORIZED,
|
|
"Missing Authorization headers",
|
|
Codes.UNAUTHORIZED,
|
|
)
|
|
|
|
for auth in auth_headers:
|
|
if auth.startswith(b"X-Matrix"):
|
|
(origin, key, sig, destination) = _parse_auth_header(auth)
|
|
json_request["origin"] = origin
|
|
json_request["signatures"].setdefault(origin, {})[key] = sig
|
|
|
|
# if the origin_server sent a destination along it needs to match our own server_name
|
|
if destination is not None and destination != self.server_name:
|
|
raise AuthenticationError(
|
|
HTTPStatus.UNAUTHORIZED,
|
|
"Destination mismatch in auth header",
|
|
Codes.UNAUTHORIZED,
|
|
)
|
|
if (
|
|
self.federation_domain_whitelist is not None
|
|
and origin not in self.federation_domain_whitelist
|
|
):
|
|
raise FederationDeniedError(origin)
|
|
|
|
if origin is None or not json_request["signatures"]:
|
|
raise NoAuthenticationError(
|
|
HTTPStatus.UNAUTHORIZED,
|
|
"Missing Authorization headers",
|
|
Codes.UNAUTHORIZED,
|
|
)
|
|
|
|
await self.keyring.verify_json_for_server(
|
|
origin,
|
|
json_request,
|
|
now,
|
|
)
|
|
|
|
logger.debug("Request from %s", origin)
|
|
request.requester = origin
|
|
|
|
# If we get a valid signed request from the other side, its probably
|
|
# alive
|
|
retry_timings = await self.store.get_destination_retry_timings(origin)
|
|
if retry_timings and retry_timings.retry_last_ts:
|
|
run_in_background(self.reset_retry_timings, origin)
|
|
|
|
return origin
|
|
|
|
async def reset_retry_timings(self, origin: str) -> None:
|
|
try:
|
|
logger.info("Marking origin %r as up", origin)
|
|
await self.store.set_destination_retry_timings(origin, None, 0, 0)
|
|
|
|
# Inform the relevant places that the remote server is back up.
|
|
self.notifier.notify_remote_server_up(origin)
|
|
if self.replication_client:
|
|
# If we're on a worker we try and inform master about this. The
|
|
# replication client doesn't hook into the notifier to avoid
|
|
# infinite loops where we send a `REMOTE_SERVER_UP` command to
|
|
# master, which then echoes it back to us which in turn pokes
|
|
# the notifier.
|
|
self.replication_client.send_remote_server_up(origin)
|
|
|
|
except Exception:
|
|
logger.exception("Error resetting retry timings on %s", origin)
|
|
|
|
|
|
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str]]:
|
|
"""Parse an X-Matrix auth header
|
|
|
|
Args:
|
|
header_bytes: header value
|
|
|
|
Returns:
|
|
origin, key id, signature, destination.
|
|
origin, key id, signature.
|
|
|
|
Raises:
|
|
AuthenticationError if the header could not be parsed
|
|
"""
|
|
try:
|
|
header_str = header_bytes.decode("utf-8")
|
|
params = re.split(" +", header_str)[1].split(",")
|
|
param_dict: Dict[str, str] = {
|
|
k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
|
|
}
|
|
|
|
def strip_quotes(value: str) -> str:
|
|
if value.startswith('"'):
|
|
return re.sub(
|
|
"\\\\(.)", lambda matchobj: matchobj.group(1), value[1:-1]
|
|
)
|
|
else:
|
|
return value
|
|
|
|
origin = strip_quotes(param_dict["origin"])
|
|
|
|
# ensure that the origin is a valid server name
|
|
parse_and_validate_server_name(origin)
|
|
|
|
key = strip_quotes(param_dict["key"])
|
|
sig = strip_quotes(param_dict["sig"])
|
|
|
|
# get the destination server_name from the auth header if it exists
|
|
destination = param_dict.get("destination")
|
|
if destination is not None:
|
|
destination = strip_quotes(destination)
|
|
else:
|
|
destination = None
|
|
|
|
return origin, key, sig, destination
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Error parsing auth header '%s': %s",
|
|
header_bytes.decode("ascii", "replace"),
|
|
e,
|
|
)
|
|
raise AuthenticationError(
|
|
HTTPStatus.BAD_REQUEST, "Malformed Authorization header", Codes.UNAUTHORIZED
|
|
)
|
|
|
|
|
|
class BaseFederationServlet:
|
|
"""Abstract base class for federation servlet classes.
|
|
|
|
The servlet object should have a PATH attribute which takes the form of a regexp to
|
|
match against the request path (excluding the /federation/v1 prefix).
|
|
|
|
The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
|
|
the appropriate HTTP method. These methods must be *asynchronous* and have the
|
|
signature:
|
|
|
|
on_<METHOD>(self, origin, content, query, **kwargs)
|
|
|
|
With arguments:
|
|
|
|
origin (str|None): The authenticated server_name of the calling server,
|
|
unless REQUIRE_AUTH is set to False and authentication failed.
|
|
|
|
content (str|None): decoded json body of the request. None if the
|
|
request was a GET.
|
|
|
|
query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
|
|
(ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
|
|
yet.
|
|
|
|
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
|
|
components as specified in the path match regexp.
|
|
|
|
Returns:
|
|
Optional[Tuple[int, object]]: either (response code, response object) to
|
|
return a JSON response, or None if the request has already been handled.
|
|
|
|
Raises:
|
|
SynapseError: to return an error code
|
|
|
|
Exception: other exceptions will be caught, logged, and a 500 will be
|
|
returned.
|
|
"""
|
|
|
|
PATH = "" # Overridden in subclasses, the regex to match against the path.
|
|
|
|
REQUIRE_AUTH = True
|
|
|
|
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
|
|
|
|
RATELIMIT = True # Whether to rate limit requests or not
|
|
|
|
def __init__(
|
|
self,
|
|
hs: "HomeServer",
|
|
authenticator: Authenticator,
|
|
ratelimiter: FederationRateLimiter,
|
|
server_name: str,
|
|
):
|
|
self.hs = hs
|
|
self.authenticator = authenticator
|
|
self.ratelimiter = ratelimiter
|
|
self.server_name = server_name
|
|
|
|
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
|
|
authenticator = self.authenticator
|
|
ratelimiter = self.ratelimiter
|
|
|
|
@functools.wraps(func)
|
|
async def new_func(
|
|
request: SynapseRequest, *args: Any, **kwargs: str
|
|
) -> Optional[Tuple[int, Any]]:
|
|
"""A callback which can be passed to HttpServer.RegisterPaths
|
|
|
|
Args:
|
|
request:
|
|
*args: unused?
|
|
**kwargs: the dict mapping keys to path components as specified
|
|
in the path match regexp.
|
|
|
|
Returns:
|
|
(response code, response object) as returned by the callback method.
|
|
None if the request has already been handled.
|
|
"""
|
|
content = None
|
|
if request.method in [b"PUT", b"POST"]:
|
|
# TODO: Handle other method types? other content types?
|
|
content = parse_json_object_from_request(request)
|
|
|
|
try:
|
|
with start_active_span("authenticate_request"):
|
|
origin: Optional[str] = await authenticator.authenticate_request(
|
|
request, content
|
|
)
|
|
except NoAuthenticationError:
|
|
origin = None
|
|
if self.REQUIRE_AUTH:
|
|
logger.warning(
|
|
"authenticate_request failed: missing authentication"
|
|
)
|
|
raise
|
|
except Exception as e:
|
|
logger.warning("authenticate_request failed: %s", e)
|
|
raise
|
|
|
|
# update the active opentracing span with the authenticated entity
|
|
set_tag("authenticated_entity", str(origin))
|
|
|
|
# if the origin is authenticated and whitelisted, use its span context
|
|
# as the parent.
|
|
context = None
|
|
if origin and whitelisted_homeserver(origin):
|
|
context = span_context_from_request(request)
|
|
|
|
if context:
|
|
servlet_span = active_span()
|
|
# a scope which uses the origin's context as a parent
|
|
processing_start_time = time.time()
|
|
scope = start_active_span_follows_from(
|
|
"incoming-federation-request",
|
|
child_of=context,
|
|
contexts=(servlet_span,),
|
|
start_time=processing_start_time,
|
|
)
|
|
|
|
else:
|
|
# just use our context as a parent
|
|
scope = start_active_span(
|
|
"incoming-federation-request",
|
|
)
|
|
|
|
try:
|
|
with scope:
|
|
if origin and self.RATELIMIT:
|
|
with ratelimiter.ratelimit(origin) as d:
|
|
await d
|
|
if request._disconnected:
|
|
logger.warning(
|
|
"client disconnected before we started processing "
|
|
"request"
|
|
)
|
|
return None
|
|
response = await func(
|
|
origin, content, request.args, *args, **kwargs
|
|
)
|
|
else:
|
|
response = await func(
|
|
origin, content, request.args, *args, **kwargs
|
|
)
|
|
finally:
|
|
# if we used the origin's context as the parent, add a new span using
|
|
# the servlet span as a parent, so that we have a link
|
|
if context:
|
|
scope2 = start_active_span_follows_from(
|
|
"process-federation_request",
|
|
contexts=(scope.span,),
|
|
start_time=processing_start_time,
|
|
)
|
|
scope2.close()
|
|
|
|
return response
|
|
|
|
return cast(ServletCallback, new_func)
|
|
|
|
def register(self, server: HttpServer) -> None:
|
|
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
|
|
|
|
for method in ("GET", "PUT", "POST"):
|
|
code = getattr(self, "on_%s" % (method), None)
|
|
if code is None:
|
|
continue
|
|
|
|
if is_function_cancellable(code):
|
|
# The wrapper added by `self._wrap` will inherit the cancellable flag,
|
|
# but the wrapper itself does not support cancellation yet.
|
|
# Once resolved, the cancellation tests in
|
|
# `tests/federation/transport/server/test__base.py` can be re-enabled.
|
|
raise Exception(
|
|
f"{self.__class__.__name__}.on_{method} has been marked as "
|
|
"cancellable, but federation servlets do not support cancellation "
|
|
"yet."
|
|
)
|
|
|
|
server.register_paths(
|
|
method,
|
|
(pattern,),
|
|
self._wrap(code),
|
|
self.__class__.__name__,
|
|
)
|