# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import re

from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
    SynapseTags,
    start_active_span,
    start_active_span_from_request,
    tags,
    whitelisted_homeserver,
)
from synapse.server import HomeServer
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name

logger = logging.getLogger(__name__)


class AuthenticationError(SynapseError):
    """There was a problem authenticating the request"""


class NoAuthenticationError(AuthenticationError):
    """The request had no authentication information"""


class Authenticator:
    """Checks the X-Matrix ``Authorization`` headers on inbound federation
    requests and records that the requesting server is reachable."""

    def __init__(self, hs: HomeServer):
        self._clock = hs.get_clock()
        self.keyring = hs.get_keyring()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        # None means "no whitelist configured": every origin is allowed.
        self.federation_domain_whitelist = (
            hs.config.federation.federation_domain_whitelist
        )
        self.notifier = hs.get_notifier()

        # Only set on workers; used to tell the main process that a remote
        # server came back up (see _reset_retry_timings).
        self.replication_client = None
        if hs.config.worker.worker_app:
            self.replication_client = hs.get_tcp_replication()

    # A method just so we can pass 'self' as the authenticator to the Servlets
    async def authenticate_request(self, request, content):
        """Verify the X-Matrix signature(s) on an inbound federation request.

        The method, URI, destination (this server) and, if present, the
        request body are assembled into a JSON object. The signatures taken
        from the ``Authorization`` headers are then verified against that
        object using the keyring.

        Args:
            request (twisted.web.http.Request): the incoming request.
            content (object|None): the decoded JSON body, or None when the
                request has no body.

        Returns:
            str: the server name of the requesting (origin) server. It is also
            stored on ``request.requester``.

        Raises:
            NoAuthenticationError: there are no Authorization headers, or none
                of them carried X-Matrix credentials.
            AuthenticationError: an X-Matrix header could not be parsed.
            FederationDeniedError: a whitelist is configured and the origin
                is not in it.

        NOTE(review): signature failures are whatever
        ``keyring.verify_json_for_server`` raises; that is not visible here.
        """
        now = self._clock.time_msec()
        json_request = {
            "method": request.method.decode("ascii"),
            "uri": request.uri.decode("ascii"),
            "destination": self.server_name,
            "signatures": {},
        }

        if content is not None:
            json_request["content"] = content

        origin = None

        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        if not auth_headers:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        for auth in auth_headers:
            if auth.startswith(b"X-Matrix"):
                (origin, key, sig) = _parse_auth_header(auth)
                json_request["origin"] = origin
                json_request["signatures"].setdefault(origin, {})[key] = sig

        # Refuse non-whitelisted servers before doing any signature work.
        if (
            self.federation_domain_whitelist is not None
            and origin not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(origin)

        if origin is None or not json_request["signatures"]:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        await self.keyring.verify_json_for_server(
            origin,
            json_request,
            now,
        )

        logger.debug("Request from %s", origin)
        request.requester = origin

        # If we get a valid signed request from the other side, its probably
        # alive
        retry_timings = await self.store.get_destination_retry_timings(origin)
        if retry_timings and retry_timings.retry_last_ts:
            # Done in the background so this request does not wait on it.
            run_in_background(self._reset_retry_timings, origin)

        return origin

    async def _reset_retry_timings(self, origin):
        """Clear the stored retry timings for ``origin`` and announce that it
        is back up.

        This is best-effort: any exception is logged and swallowed, because
        the method is run in the background from authenticate_request.
        """
        try:
            logger.info("Marking origin %r as up", origin)
            await self.store.set_destination_retry_timings(origin, None, 0, 0)

            # Inform the relevant places that the remote server is back up.
            self.notifier.notify_remote_server_up(origin)
            if self.replication_client:
                # If we're on a worker we try and inform master about this. The
                # replication client doesn't hook into the notifier to avoid
                # infinite loops where we send a `REMOTE_SERVER_UP` command to
                # master, which then echoes it back to us which in turn pokes
                # the notifier.
                self.replication_client.send_remote_server_up(origin)

        except Exception:
            logger.exception("Error resetting retry timings on %s", origin)


def _parse_auth_header(header_bytes):
    """Parse an X-Matrix auth header

    The expected shape is
    ``X-Matrix origin=<server>,key="<key id>",sig="<signature>"``. Values may
    optionally be surrounded by double quotes.

    Args:
        header_bytes (bytes): header value

    Returns:
        Tuple[str, str, str]: origin, key id, signature.

    Raises:
        AuthenticationError if the header could not be parsed
    """
    try:
        header_str = header_bytes.decode("utf-8")
        params = header_str.split(" ")[1].split(",")
        # Split each parameter on the first "=" only. A value that itself
        # contains "=" (for example padded base64) must not make the whole
        # header unparseable.
        param_dict = dict(kv.split("=", 1) for kv in params)

        def strip_quotes(value):
            # Only remove a matching pair of surrounding double quotes. An
            # unterminated quote is left in place so that the value fails
            # validation, rather than silently losing its last character.
            if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
                return value[1:-1]
            else:
                return value

        origin = strip_quotes(param_dict["origin"])

        # ensure that the origin is a valid server name
        parse_and_validate_server_name(origin)

        key = strip_quotes(param_dict["key"])
        sig = strip_quotes(param_dict["sig"])
        return origin, key, sig
    except Exception as e:
        logger.warning(
            "Error parsing auth header '%s': %s",
            header_bytes.decode("ascii", "replace"),
            e,
        )
        raise AuthenticationError(
            400, "Malformed Authorization header", Codes.UNAUTHORIZED
        )


class BaseFederationServlet:
    """Abstract base class for federation servlet classes.

    The servlet object should have a PATH attribute which takes the form of a regexp to
    match against the request path (excluding the /federation/v1 prefix).

    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
    the appropriate HTTP method. These methods must be *asynchronous* and have the
    signature:

        on_<METHOD>(self, origin, content, query, **kwargs)

        With arguments:

            origin (unicode|None): The authenticated server_name of the calling server,
                unless REQUIRE_AUTH is set to False and authentication failed.

            content (unicode|None): decoded json body of the request. None if the
                request was a GET.

            query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
                (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
                yet.

            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
                components as specified in the path match regexp.

        Returns:
            Optional[Tuple[int, object]]: either (response code, response object) to
                 return a JSON response, or None if the request has already been handled.

        Raises:
            SynapseError: to return an error code

            Exception: other exceptions will be caught, logged, and a 500 will be
                returned.
    """

    PATH = ""  # Overridden in subclasses, the regex to match against the path.

    REQUIRE_AUTH = True

    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version

    RATELIMIT = True  # Whether to rate limit requests or not

    def __init__(
        self,
        hs: HomeServer,
        authenticator: Authenticator,
        ratelimiter: FederationRateLimiter,
        server_name: str,
    ):
        self.hs = hs
        self.authenticator = authenticator
        self.ratelimiter = ratelimiter
        self.server_name = server_name

    def _wrap(self, func):
        """Wrap an ``on_<METHOD>`` handler with body parsing, authentication,
        tracing and (optionally) per-origin rate limiting.

        Args:
            func: the servlet's bound ``on_<METHOD>`` coroutine function.

        Returns:
            the async callback that is registered with the HTTP server.
        """
        authenticator = self.authenticator
        ratelimiter = self.ratelimiter

        @functools.wraps(func)
        async def new_func(request, *args, **kwargs):
            """A callback which can be passed to HttpServer.RegisterPaths

            Args:
                request (twisted.web.http.Request):
                *args: unused?
                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
                    components as specified in the path match regexp.

            Returns:
                Tuple[int, object]|None: (response code, response object) as returned by
                    the callback method. None if the request has already been handled.
            """
            content = None
            if request.method in [b"PUT", b"POST"]:
                # TODO: Handle other method types? other content types?
                content = parse_json_object_from_request(request)

            try:
                origin = await authenticator.authenticate_request(request, content)
            except NoAuthenticationError:
                # Servlets with REQUIRE_AUTH = False may be called anonymously
                # (origin=None); all others propagate the error.
                origin = None
                if self.REQUIRE_AUTH:
                    logger.warning(
                        "authenticate_request failed: missing authentication"
                    )
                    raise
            except Exception as e:
                logger.warning("authenticate_request failed: %s", e)
                raise

            request_tags = {
                SynapseTags.REQUEST_ID: request.get_request_id(),
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                tags.HTTP_METHOD: request.get_method(),
                tags.HTTP_URL: request.get_redacted_uri(),
                tags.PEER_HOST_IPV6: request.getClientIP(),
                "authenticated_entity": origin,
                "servlet_name": request.request_metrics.name,
            }

            # Only accept the span context if the origin is authenticated
            # and whitelisted
            if origin and whitelisted_homeserver(origin):
                scope = start_active_span_from_request(
                    request, "incoming-federation-request", tags=request_tags
                )
            else:
                scope = start_active_span(
                    "incoming-federation-request", tags=request_tags
                )

            with scope:
                opentracing.inject_response_headers(request.responseHeaders)

                if origin and self.RATELIMIT:
                    # The handler runs inside the ratelimit context so that it
                    # is counted for as long as it is processing.
                    with ratelimiter.ratelimit(origin) as d:
                        await d
                        if request._disconnected:
                            logger.warning(
                                "client disconnected before we started processing "
                                "request"
                            )
                            return -1, None
                        response = await func(
                            origin, content, request.args, *args, **kwargs
                        )
                else:
                    response = await func(
                        origin, content, request.args, *args, **kwargs
                    )

            return response

        return new_func

    def register(self, server):
        """Register each implemented ``on_GET``/``on_PUT``/``on_POST`` handler
        with ``server`` under ``PREFIX + PATH``, anchored at both ends."""
        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")

        for method in ("GET", "PUT", "POST"):
            code = getattr(self, "on_%s" % (method), None)
            if code is None:
                continue

            server.register_paths(
                method,
                (pattern,),
                self._wrap(code),
                self.__class__.__name__,
            )