Add most of the missing type hints to synapse.federation. (#11483)

This skips a few methods which are difficult to type.
This commit is contained in:
Patrick Cloke 2021-12-02 11:18:10 -05:00 committed by GitHub
parent b50e39df57
commit d2279f471b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 84 additions and 49 deletions

View file

@ -15,10 +15,13 @@
import functools
import logging
import re
from typing import Any, Awaitable, Callable, Optional, Tuple, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import parse_and_validate_server_name
@ -59,9 +63,11 @@ class Authenticator:
self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content):
async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict]
) -> str:
now = self._clock.time_msec()
json_request = {
json_request: JsonDict = {
"method": request.method.decode("ascii"),
"uri": request.uri.decode("ascii"),
"destination": self.server_name,
@ -114,7 +120,7 @@ class Authenticator:
return origin
async def _reset_retry_timings(self, origin):
async def _reset_retry_timings(self, origin: str) -> None:
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
@ -133,14 +139,14 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
header_bytes: header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
try:
header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
def strip_quotes(value):
def strip_quotes(value: str) -> str:
if value.startswith('"'):
return value[1:-1]
else:
@ -233,23 +239,25 @@ class BaseFederationServlet:
self.ratelimiter = ratelimiter
self.server_name = server_name
def _wrap(self, func):
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@functools.wraps(func)
async def new_func(request, *args, **kwargs):
async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[Tuple[int, Any]]:
"""A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
request:
*args: unused?
**kwargs (dict[unicode, unicode]): the dict mapping keys to path
components as specified in the path match regexp.
**kwargs: the dict mapping keys to path components as specified
in the path match regexp.
Returns:
Tuple[int, object]|None: (response code, response object) as returned by
the callback method. None if the request has already been handled.
(response code, response object) as returned by the callback method.
None if the request has already been handled.
"""
content = None
if request.method in [b"PUT", b"POST"]:
@ -257,7 +265,9 @@ class BaseFederationServlet:
content = parse_json_object_from_request(request)
try:
origin = await authenticator.authenticate_request(request, content)
origin: Optional[str] = await authenticator.authenticate_request(
request, content
)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
@ -301,7 +311,7 @@ class BaseFederationServlet:
"client disconnected before we started processing "
"request"
)
return -1, None
return None
response = await func(
origin, content, request.args, *args, **kwargs
)
@ -312,9 +322,9 @@ class BaseFederationServlet:
return response
return new_func
return cast(ServletCallback, new_func)
def register(self, server):
def register(self, server: HttpServer) -> None:
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):