Revert "Revert accidental fast-forward merge from v1.49.0rc1"

This reverts commit 158d73ebdd.
This commit is contained in:
Olivier Wilkinson (reivilibre) 2021-12-14 14:22:01 +00:00
parent 158d73ebdd
commit 4dd9ea8f4f
165 changed files with 7715 additions and 2703 deletions

View file

@ -15,10 +15,13 @@
import functools
import logging
import re
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name
@ -59,9 +63,11 @@ class Authenticator:
self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content):
async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict]
) -> str:
now = self._clock.time_msec()
json_request = {
json_request: JsonDict = {
"method": request.method.decode("ascii"),
"uri": request.uri.decode("ascii"),
"destination": self.server_name,
@ -114,7 +120,7 @@ class Authenticator:
return origin
async def _reset_retry_timings(self, origin):
async def _reset_retry_timings(self, origin: str) -> None:
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
@ -133,14 +139,14 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
header_bytes: header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
try:
header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
def strip_quotes(value):
def strip_quotes(value: str) -> str:
if value.startswith('"'):
return value[1:-1]
else:
@ -233,23 +239,25 @@ class BaseFederationServlet:
self.ratelimiter = ratelimiter
self.server_name = server_name
def _wrap(self, func):
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@functools.wraps(func)
async def new_func(request, *args, **kwargs):
async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[Tuple[int, Any]]:
"""A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
request:
*args: unused?
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
**kwargs: the dict mapping keys to path components as specified
in the path match regexp.
Returns:
Tuple[int, object]|None: (response code, response object) as returned by
the callback method. None if the request has already been handled.
(response code, response object) as returned by the callback method.
None if the request has already been handled.
"""
content = None
if request.method in [b"PUT", b"POST"]:
@ -257,7 +265,9 @@ class BaseFederationServlet:
content = parse_json_object_from_request(request)
try:
origin = await authenticator.authenticate_request(request, content)
origin: Optional[str] = await authenticator.authenticate_request(
request, content
)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
@ -301,7 +311,7 @@ class BaseFederationServlet:
"client disconnected before we started processing "
"request"
)
return -1, None
return None
response = await func(
origin, content, request.args, *args, **kwargs
)
@ -312,9 +322,9 @@ class BaseFederationServlet:
return response
return new_func
return cast(ServletCallback, new_func)
def register(self, server):
def register(self, server: HttpServer) -> None:
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):