diff --git a/changelog.d/10867.misc b/changelog.d/10867.misc new file mode 100644 index 000000000..01e51fbc6 --- /dev/null +++ b/changelog.d/10867.misc @@ -0,0 +1 @@ +Add type hints to `synapse.http.site`. diff --git a/synapse/http/site.py b/synapse/http/site.py index c665a9d5d..dd4c749e1 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -21,7 +21,7 @@ from zope.interface import implementer from twisted.internet.interfaces import IAddress, IReactorTime from twisted.python.failure import Failure -from twisted.web.resource import IResource +from twisted.web.resource import IResource, Resource from twisted.web.server import Request, Site from synapse.config.server import ListenerConfig @@ -61,7 +61,7 @@ class SynapseRequest(Request): logcontext: the log context for this request """ - def __init__(self, channel, *args, max_request_body_size=1024, **kw): + def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw): Request.__init__(self, channel, *args, **kw) self._max_request_body_size = max_request_body_size self.site: SynapseSite = channel.site @@ -83,13 +83,13 @@ class SynapseRequest(Request): self._is_processing = False # the time when the asynchronous request handler completed its processing - self._processing_finished_time = None + self._processing_finished_time: Optional[float] = None # what time we finished sending the response to the client (or the connection # dropped) - self.finish_time = None + self.finish_time: Optional[float] = None - def __repr__(self): + def __repr__(self) -> str: # We overwrite this so that we don't log ``access_token`` return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( self.__class__.__name__, @@ -100,7 +100,7 @@ class SynapseRequest(Request): self.site.site_tag, ) - def handleContentChunk(self, data): + def handleContentChunk(self, data: bytes) -> None: # we should have a `content` by now. assert self.content, "handleContentChunk() called before gotLength()" if self.content.tell() + len(data) > self._max_request_body_size: @@ -139,7 +139,7 @@ class SynapseRequest(Request): # If there's no authenticated entity, it was the requester. self.logcontext.request.authenticated_entity = authenticated_entity or requester - def get_request_id(self): + def get_request_id(self) -> str: return "%s-%i" % (self.get_method(), self.request_seq) def get_redacted_uri(self) -> str: @@ -205,7 +205,7 @@ class SynapseRequest(Request): return None, None - def render(self, resrc): + def render(self, resrc: Resource) -> None: # this is called once a Resource has been found to serve the request; in our # case the Resource in question will normally be a JsonResource. @@ -282,7 +282,7 @@ class SynapseRequest(Request): if self.finish_time is not None: self._finished_processing() - def finish(self): + def finish(self) -> None: """Called when all response data has been written to this Request. Overrides twisted.web.server.Request.finish to record the finish time and do @@ -295,7 +295,7 @@ class SynapseRequest(Request): with PreserveLoggingContext(self.logcontext): self._finished_processing() - def connectionLost(self, reason): + def connectionLost(self, reason: Union[Failure, Exception]) -> None: """Called when the client connection is closed before the response is written. Overrides twisted.web.server.Request.connectionLost to record the finish time and @@ -327,7 +327,7 @@ class SynapseRequest(Request): if not self._is_processing: self._finished_processing() - def _started_processing(self, servlet_name): + def _started_processing(self, servlet_name: str) -> None: """Record the fact that we are processing this request. This will log the request's arrival. Once the request completes, @@ -354,9 +354,11 @@ class SynapseRequest(Request): self.get_redacted_uri(), ) - def _finished_processing(self): + def _finished_processing(self) -> None: """Log the completion of this request and update the metrics""" assert self.logcontext is not None + assert self.finish_time is not None + usage = self.logcontext.get_resource_usage() if self._processing_finished_time is None: @@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest): _forwarded_for: "Optional[_XForwardedForAddress]" = None _forwarded_https: bool = False - def requestReceived(self, command, path, version): + def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None: # this method is called by the Channel once the full request has been # received, to dispatch the request to a resource. # We can use it to set the IP address and protocol according to the @@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest): self._process_forwarded_headers() return super().requestReceived(command, path, version) - def _process_forwarded_headers(self): + def _process_forwarded_headers(self) -> None: headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") if not headers: return @@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest): ) self._forwarded_https = True - def isSecure(self): + def isSecure(self) -> bool: if self._forwarded_https: return True return super().isSecure() @@ -545,14 +547,16 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest - def request_factory(channel, queued) -> Request: + def request_factory(channel, queued: bool) -> Request: return request_class( - channel, max_request_body_size=max_request_body_size, queued=queued + channel, + max_request_body_size=max_request_body_size, + queued=queued, ) self.requestFactory = request_factory # type: ignore self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") - def log(self, request): + def log(self, request: SynapseRequest) -> None: pass