diff --git a/changelog.d/11571.misc b/changelog.d/11571.misc new file mode 100644 index 000000000..4e396b271 --- /dev/null +++ b/changelog.d/11571.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.http`. diff --git a/mypy.ini b/mypy.ini index a7b1f4eb6..9aeeca2bb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -161,6 +161,9 @@ disallow_untyped_defs = False [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.http.server] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 578fc48ef..efecb089c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(504, msg) @@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") -def redact_uri(uri): +def redact_uri(uri: str) -> str: """Strips sensitive information from the uri replaces with """ uri = ACCESS_TOKEN_RE.sub(r"\1\3", uri) return CLIENT_SECRET_RE.sub(r"\1\3", uri) @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer): https://twistedmatrix.com/trac/ticket/6528 """ - def stopProducing(self): + def stopProducing(self) -> None: try: FileBodyProducer.stopProducing(self) except task.TaskStopped: diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 9a2684aca..6a9f6635d 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource): and exception handling. """ - def __init__(self, hs: "HomeServer", handler): + def __init__( + self, + hs: "HomeServer", + handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]], + ): """Initialise AdditionalResource The ``handler`` should return a deferred which completes when it has @@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource): super().__init__() self._handler = handler - def _async_render(self, request: Request): + async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]: # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. - return self._handler(request) + return await self._handler(request) diff --git a/synapse/http/server.py b/synapse/http/server.py index 91badb0b0..4fd5660a0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,6 +30,7 @@ from typing import ( Iterable, Iterator, List, + NoReturn, Optional, Pattern, Tuple, @@ -170,7 +171,9 @@ def return_html_error( respond_with_html(request, code, body) -def wrap_async_request_handler(h): +def wrap_async_request_handler( + h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] +) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -183,7 +186,9 @@ def wrap_async_request_handler(h): logged until the deferred completes. """ - async def wrapped_async_request_handler(self, request): + async def wrapped_async_request_handler( + self: "_AsyncResource", request: SynapseRequest + ) -> None: with request.processing(): await h(self, request) @@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): context from the request the servlet is handling. """ - def __init__(self, extract_context=False): + def __init__(self, extract_context: bool = False): super().__init__() self._extract_context = extract_context - def render(self, request): + def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest): + async def _async_render_wrapper(self, request: SynapseRequest) -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: Request): + async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource): formatting responses and errors as JSON. """ - def __init__(self, canonical_json=False, extract_context=False): + def __init__(self, canonical_json: bool = False, extract_context: bool = False): super().__init__(extract_context) self.canonical_json = canonical_json @@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( @@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): + def __init__( + self, + hs: "HomeServer", + canonical_json: bool = True, + extract_context: bool = False, + ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs - def register_paths(self, method, path_patterns, callback, servlet_classname): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate handler for that request. Args: - method (str): GET, POST etc + method: GET, POST etc - path_patterns (Iterable[str]): A list of regular expressions to which - the request URLs are compared. + path_patterns: A list of regular expressions to which the request + URLs are compared. - callback (function): The handler for the request. Usually a Servlet + callback: The handler for the request. Usually a Servlet - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ - method = method.encode("utf-8") # method is bytes on py3 + method_bytes = method.encode("utf-8") for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method, []).append( + self.path_regexs.setdefault(method_bytes, []).append( _PathEntry(path_pattern, callback, servlet_classname) ) @@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, "unrecognised_request_handler", {} - async def _async_render(self, request): + async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) # Make sure we have an appropriate name for this handler in prometheus @@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) @@ -492,12 +508,12 @@ class StaticResource(File): Differs from the File resource by adding clickjacking protection. """ - def render_GET(self, request: Request): + def render_GET(self, request: Request) -> bytes: set_clickjacking_protection_headers(request) return super().render_GET(request) -def _unrecognised_request_handler(request): +def _unrecognised_request_handler(request: Request) -> NoReturn: """Request handler for unrecognised requests This is a request handler suitable for return from @@ -505,7 +521,7 @@ def _unrecognised_request_handler(request): UnrecognizedRequestError. Args: - request (twisted.web.http.Request): + request: Unused, but passed in to match the signature of ServletCallback. """ raise UnrecognizedRequestError() @@ -513,14 +529,14 @@ def _unrecognised_request_handler(request): class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" - def __init__(self, path): + def __init__(self, path: str): resource.Resource.__init__(self) self.url = path - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: return redirectTo(self.url.encode("ascii"), request) - def getChild(self, name, request): + def getChild(self, name: str, request: Request) -> resource.Resource: if len(name) == 0: return self # select ourselves as the child to render return resource.Resource.getChild(self, name, request) @@ -529,7 +545,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request): + def render_OPTIONS(self, request: Request) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -537,7 +553,7 @@ class OptionsResource(resource.Resource): return b"" - def getChildWithDefault(self, path, request): + def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: if request.method == b"OPTIONS": return self # select ourselves as the child to render return resource.Resource.getChildWithDefault(self, path, request) @@ -649,7 +665,7 @@ def respond_with_json( json_object: Any, send_cors: bool = False, canonical_json: bool = True, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -696,7 +712,7 @@ def respond_with_json_bytes( code: int, json_bytes: bytes, send_cors: bool = False, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -713,7 +729,7 @@ def respond_with_json_bytes( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread( request: SynapseRequest, json_encoder: Callable[[Any], bytes], json_object: Any, -): +) -> None: """Encodes the given JSON object on a thread and then writes it to the request. @@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request): +def set_cors_headers(request: Request) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -790,14 +806,14 @@ def set_cors_headers(request: Request): ) -def respond_with_html(request: Request, code: int, html: str): +def respond_with_html(request: Request, code: int, html: str) -> None: """ Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. """ respond_with_html_bytes(request, code, html.encode("utf-8")) -def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None: """ Sends HTML (encoded as UTF-8 bytes) as the response to the given request. @@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") @@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): finish_request(request) -def set_clickjacking_protection_headers(request: Request): +def set_clickjacking_protection_headers(request: Request) -> None: """ Set headers to guard against clickjacking of embedded content. @@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: finish_request(request) -def finish_request(request: Request): +def finish_request(request: Request) -> None: """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index e543cc6e0..4ff840ca0 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -31,6 +31,7 @@ from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -726,7 +727,7 @@ class RestServlet: into the appropriate HTTP response. """ - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Register this servlet with the given HTTP server.""" patterns = getattr(self, "PATTERNS", None) if patterns: diff --git a/synapse/http/site.py b/synapse/http/site.py index 755ad5663..9f68d7e19 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Generator, Optional, Tuple, Union +from typing import Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -66,9 +66,9 @@ class SynapseRequest(Request): self, channel: HTTPChannel, site: "SynapseSite", - *args, + *args: Any, max_request_body_size: int = 1024, - **kw, + **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size @@ -557,7 +557,7 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest - def request_factory(channel, queued: bool) -> Request: + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 12b3ae120..b9bfbea21 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> int: + def render_GET(self, request: Request) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: