Merge remote-tracking branch 'upstream/release-v1.44'

This commit is contained in:
Tulir Asokan 2021-09-29 16:08:42 +03:00
commit 8631aaeb5a
243 changed files with 3908 additions and 2190 deletions

View file

@ -322,8 +322,11 @@ class SimpleHttpClient:
self.user_agent = user_agent or hs.version_string
self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
if hs.config.server.user_agent_suffix:
self.user_agent = "%s %s" % (
self.user_agent,
hs.config.server.user_agent_suffix,
)
# We use this for our body producers to ensure that they use the correct
# reactor.

View file

@ -66,7 +66,7 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
@ -465,8 +465,9 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if (
self.hs.config.federation_domain_whitelist is not None
and request.destination not in self.hs.config.federation_domain_whitelist
self.hs.config.federation.federation_domain_whitelist is not None
and request.destination
not in self.hs.config.federation.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@ -553,20 +554,29 @@ class MatrixFederationHttpClient:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
# To preserve the logging context, the timeout is treated
# in a similar way to `defer.gatherResults`:
# * Each logging context-preserving fork is wrapped in
# `run_in_background`. In this case there is only one,
# since the timeout fork is not logging-context aware.
# * The `Deferred` that joins the forks back together is
# wrapped in `make_deferred_yieldable` to restore the
# logging context regardless of the path taken.
request_deferred = run_in_background(
self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
response = await request_deferred
response = await make_deferred_yieldable(request_deferred)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
@ -1177,7 +1187,7 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
return (length, headers)
return length, headers
def _flatten_response_never_received(e):

View file

@ -21,7 +21,6 @@ import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
Awaitable,
@ -37,7 +36,7 @@ from typing import (
)
import jinja2
from canonicaljson import iterencode_canonical_json
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer
@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.static import File
from twisted.web.util import redirectTo
from synapse.api.errors import (
@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import preserve_fn
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
request: Request,
request: SynapseRequest,
code: int,
response_object: Any,
):
@ -561,9 +561,17 @@ class _ByteProducer:
self._iterator = iterator
self._paused = False
# Register the producer and start producing data.
self._request.registerProducer(self, True)
self.resumeProducing()
try:
self._request.registerProducer(self, True)
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
# We drop our references to data we'll not use.
self._request = None
self._iterator = iter(())
else:
# Start producing if `registerProducer` was successful
self.resumeProducing()
def _send_data(self, data: List[bytes]) -> None:
"""
@ -620,16 +628,15 @@ class _ByteProducer:
self._request = None
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
def _encode_json_bytes(json_object: Any) -> bytes:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
for chunk in json_encoder.iterencode(json_object):
yield chunk.encode("utf-8")
return json_encoder.encode(json_object).encode("utf-8")
def respond_with_json(
request: Request,
request: SynapseRequest,
code: int,
json_object: Any,
send_cors: bool = False,
@ -659,7 +666,7 @@ def respond_with_json(
return None
if canonical_json:
encoder = iterencode_canonical_json
encoder = encode_canonical_json
else:
encoder = _encode_json_bytes
@ -670,7 +677,9 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
_ByteProducer(request, encoder(json_object))
run_in_background(
_async_write_json_to_request_in_thread, request, encoder, json_object
)
return NOT_DONE_YET
@ -706,15 +715,56 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
# note that this is zero-copy (the bytesio shares a copy-on-write buffer with
# the original `bytes`).
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
_write_bytes_to_request(request, json_bytes)
return NOT_DONE_YET
async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
"""Encodes the given JSON object on a thread and then writes it to the
request.
This is done so that encoding large JSON objects doesn't block the reactor
thread.
Note: We don't use JsonEncoder.iterencode here as that falls back to the
Python implementation (rather than the C backend), which is *much* more
expensive.
"""
json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
_write_bytes_to_request(request, json_str)
def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
"""Writes the bytes to the request using an appropriate producer.
Note: This should be used instead of `Request.write` to correctly handle
large response bodies.
"""
# The problem with dumping all of the response into the `Request` object at
# once (via `Request.write`) is that doing so starts the timeout for the
# next request to be received: so if it takes longer than 60s to stream back
# the response to the client, the client never gets it.
#
# The correct solution is to use a Producer; then the timeout is only
# started once all of the content is sent over the TCP connection.
# To make sure we don't write all of the bytes at once we split it up into
# chunks.
chunk_size = 4096
bytes_generator = chunk_seq(bytes_to_write, chunk_size)
# We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
# unit tests can't cope with being given a pull producer.
_ByteProducer(request, bytes_generator)
def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can
use this API

View file

@ -14,14 +14,15 @@
import contextlib
import logging
import time
from typing import Optional, Tuple, Union
from typing import Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
def __init__(
self,
channel: HTTPChannel,
site: "SynapseSite",
*args,
max_request_body_size: int = 1024,
**kw,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
@ -83,13 +92,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
self.finish_time: Optional[float] = None
def __repr__(self):
def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@ -97,10 +106,10 @@ class SynapseRequest(Request):
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
self.synapse_site.site_tag,
)
def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@ -139,7 +148,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self):
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@ -205,7 +214,7 @@ class SynapseRequest(Request):
return None, None
def render(self, resrc):
def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
site_tag=self.site.site_tag,
site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
@ -228,7 +237,7 @@ class SynapseRequest(Request):
)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@ -282,7 +291,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
def finish(self):
def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@ -295,7 +304,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@ -327,7 +336,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name):
def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@ -346,17 +355,19 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method()
)
self.site.access_logger.debug(
self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
)
def _finished_processing(self):
def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@ -386,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity:
requester = f"{authenticated_entity}|{requester}"
self.site.access_logger.log(
self.synapse_site.access_logger.log(
log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
self.synapse_site.site_tag,
requester,
processing_time,
response_send_time,
@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
def requestReceived(self, command, path, version):
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
def isSecure(self):
def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@ -520,7 +531,7 @@ class SynapseSite(Site):
site_tag: str,
config: ListenerConfig,
resource: IResource,
server_version_string,
server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
):
@ -540,19 +551,23 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
self.reactor = reactor
assert config.http_options is not None
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request:
def request_factory(channel, queued: bool) -> Request:
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
def log(self, request: SynapseRequest) -> None:
pass