Add types to http.site (#10867)

This commit is contained in:
Erik Johnston 2021-09-21 17:41:27 +01:00 committed by GitHub
parent ebd8baf61f
commit b25a494779
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 18 deletions

1
changelog.d/10867.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.http.site`.

View File

@ -21,7 +21,7 @@ from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.resource import IResource from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig from synapse.config.server import ListenerConfig
@ -61,7 +61,7 @@ class SynapseRequest(Request):
logcontext: the log context for this request logcontext: the log context for this request
""" """
def __init__(self, channel, *args, max_request_body_size=1024, **kw): def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site self.site: SynapseSite = channel.site
@ -83,13 +83,13 @@ class SynapseRequest(Request):
self._is_processing = False self._is_processing = False
# the time when the asynchronous request handler completed its processing # the time when the asynchronous request handler completed its processing
self._processing_finished_time = None self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection # what time we finished sending the response to the client (or the connection
# dropped) # dropped)
self.finish_time = None self.finish_time: Optional[float] = None
def __repr__(self): def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token`` # We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__, self.__class__.__name__,
@ -100,7 +100,7 @@ class SynapseRequest(Request):
self.site.site_tag, self.site.site_tag,
) )
def handleContentChunk(self, data): def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now. # we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()" assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size: if self.content.tell() + len(data) > self._max_request_body_size:
@ -139,7 +139,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester. # If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self): def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq) return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str: def get_redacted_uri(self) -> str:
@ -205,7 +205,7 @@ class SynapseRequest(Request):
return None, None return None, None
def render(self, resrc): def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our # this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource. # case the Resource in question will normally be a JsonResource.
@ -282,7 +282,7 @@ class SynapseRequest(Request):
if self.finish_time is not None: if self.finish_time is not None:
self._finished_processing() self._finished_processing()
def finish(self): def finish(self) -> None:
"""Called when all response data has been written to this Request. """Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do Overrides twisted.web.server.Request.finish to record the finish time and do
@ -295,7 +295,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext): with PreserveLoggingContext(self.logcontext):
self._finished_processing() self._finished_processing()
def connectionLost(self, reason): def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written. """Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and Overrides twisted.web.server.Request.connectionLost to record the finish time and
@ -327,7 +327,7 @@ class SynapseRequest(Request):
if not self._is_processing: if not self._is_processing:
self._finished_processing() self._finished_processing()
def _started_processing(self, servlet_name): def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request. """Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes, This will log the request's arrival. Once the request completes,
@ -354,9 +354,11 @@ class SynapseRequest(Request):
self.get_redacted_uri(), self.get_redacted_uri(),
) )
def _finished_processing(self): def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics""" """Log the completion of this request and update the metrics"""
assert self.logcontext is not None assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage() usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None: if self._processing_finished_time is None:
@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None _forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False _forwarded_https: bool = False
def requestReceived(self, command, path, version): def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been # this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource. # received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the # We can use it to set the IP address and protocol according to the
@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers() self._process_forwarded_headers()
return super().requestReceived(command, path, version) return super().requestReceived(command, path, version)
def _process_forwarded_headers(self): def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers: if not headers:
return return
@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest):
) )
self._forwarded_https = True self._forwarded_https = True
def isSecure(self): def isSecure(self) -> bool:
if self._forwarded_https: if self._forwarded_https:
return True return True
return super().isSecure() return super().isSecure()
@ -545,14 +547,16 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request: def request_factory(channel, queued: bool) -> Request:
return request_class( return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued channel,
max_request_body_size=max_request_body_size,
queued=queued,
) )
self.requestFactory = request_factory # type: ignore self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii") self.server_version_string = server_version_string.encode("ascii")
def log(self, request): def log(self, request: SynapseRequest) -> None:
pass pass