Merge remote-tracking branch 'upstream/release-v1.29.0'

This commit is contained in:
Tulir Asokan 2021-03-04 12:49:37 +02:00
commit adb990d8ba
126 changed files with 2399 additions and 765 deletions

View file

@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Union
from twisted.internet import task
from twisted.internet import address, task
from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest
@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
pass
def get_request_uri(request: IRequest) -> bytes:
"""Return the full URI that was requested by the client"""
return b"%s://%s%s" % (
b"https" if request.isSecure() else b"http",
_get_requested_host(request),
# despite its name, "request.uri" is only the path and query-string.
request.uri,
)
def _get_requested_host(request: IRequest) -> bytes:
hostname = request.getHeader(b"host")
if hostname:
return hostname
# no Host header, use the address/port that the request arrived on
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
hostname = host.host.encode("ascii")
if request.isSecure() and host.port == 443:
# default port for https
return hostname
if not request.isSecure() and host.port == 80:
# default port for http
return hostname
return b"%s:%i" % (
hostname,
host.port,
)
def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header.

View file

@ -289,9 +289,8 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
user_agent: Optional[str] = None,
use_proxy: bool = False,
):
"""
Args:
@ -301,8 +300,8 @@ class SimpleHttpClient:
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
"""
self.hs = hs
@ -346,8 +345,7 @@ class SimpleHttpClient:
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
http_proxy=http_proxy,
https_proxy=https_proxy,
use_proxy=use_proxy,
)
if self._ip_blacklist:
@ -751,7 +749,32 @@ class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded."""
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
def _maybe_fail(self):
"""
Report a max size exceed error and disconnect the first time this is called.
"""
if not self.deferred.called:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
def connectionLost(self, reason: Failure) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@ -808,13 +831,15 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
d = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_size is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_size:
return defer.fail(BodyExceededMaxSize())
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d
d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import List, Optional
from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""
Args:
method: HTTP method: GET/POST/etc
@ -177,17 +177,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
headers = Headers()
request_headers = Headers()
else:
headers = headers.copy()
request_headers = headers.copy()
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc)
if not headers.hasHeader(b"user-agent"):
headers.addRawHeader(b"user-agent", self.user_agent)
if not request_headers.hasHeader(b"host"):
request_headers.addRawHeader(b"host", parsed_uri.netloc)
if not request_headers.hasHeader(b"user-agent"):
request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer)
self._agent.request(method, uri, request_headers, bodyProducer)
)
return res

View file

@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
content_type_headers = headers.getRawHeaders(b"Content-Type")
if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)
c_type = c_type[0].decode("ascii") # only the first header
c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
import re
from urllib.request import getproxies_environment, proxy_bypass_environment
from zope.interface import implementer
@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created.
use_proxy (bool): Whether proxy settings should be discovered and used
from conventional environment variables.
"""
def __init__(
@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None,
bindAddress=None,
pool=None,
http_proxy=None,
https_proxy=None,
use_proxy=False,
):
_AgentBase.__init__(self, reactor, pool)
@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress
http_proxy = None
https_proxy = None
no_proxy = None
if use_proxy:
proxies = getproxies_environment()
http_proxy = proxies["http"].encode() if "http" in proxies else None
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self.no_proxy = no_proxy
self._policy_for_https = contextFactory
self._reactor = reactor
@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
should_skip_proxy = False
if self.no_proxy is not None:
should_skip_proxy = proxy_bypass_environment(
parsed_uri.host.decode(),
proxies={"no": self.no_proxy},
)
if (
parsed_uri.scheme == b"http"
and self.http_proxy_endpoint
and not should_skip_proxy
):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint
request_path = uri
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
elif (
parsed_uri.scheme == b"https"
and self.https_proxy_endpoint
and not should_skip_proxy
):
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,

View file

@ -21,6 +21,7 @@ import logging
import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
@ -30,6 +31,7 @@ from typing import (
Iterable,
Iterator,
List,
Optional,
Pattern,
Tuple,
Union,
@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
error_code = f.value.code
error_dict = f.value.error_dict()
# mypy doesn't understand that f.check asserts the type.
exc = f.value # type: SynapseError # type: ignore
error_code = exc.code
error_dict = exc.error_dict()
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
# Only respond with an error response if we haven't already started writing,
@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
cme = f.value
# mypy doesn't understand that f.check asserts the type.
cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg
@ -142,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
@ -151,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
if isinstance(error_template, str):
@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
# At this point the path must be bytes.
request_path_bytes = request.path # type: bytes # type: ignore
request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
@ -551,7 +558,7 @@ class _ByteProducer:
request: Request,
iterator: Iterator[bytes],
):
self._request = request
self._request = request # type: Optional[Request]
self._iterator = iterator
self._paused = False
@ -563,7 +570,7 @@ class _ByteProducer:
"""
Send a list of bytes as a chunk of a response.
"""
if not data:
if not data or not self._request:
return
self._request.write(b"".join(data))

View file

@ -14,8 +14,12 @@
import contextlib
import logging
import time
from typing import Optional, Union
from typing import Optional, Type, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@ -53,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@ -92,25 +96,34 @@ class SynapseRequest(Request):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self):
uri = self.uri
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if the URI
has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.uri`.
Returns:
The redacted URI as a string.
"""
uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
uri = self.uri.decode("ascii", errors="replace")
uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
def get_method(self):
"""Gets the method associated with the request (or placeholder if not
method has yet been received).
def get_method(self) -> str:
"""Gets the method associated with the request (or placeholder if method
has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.
Returns:
str
The request method as a string.
"""
method = self.method
method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
method = self.method.decode("ascii")
return self.method.decode("ascii")
return method
def render(self, resrc):
@ -333,27 +346,78 @@ class SynapseRequest(Request):
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""Request object which honours proxy headers
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
information from request headers.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return (
self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
.split(b",")[0]
.strip()
.decode("ascii")
# the client IP and ssl flag, as extracted from the headers.
_forwarded_for = None # type: Optional[_XForwardedForAddress]
_forwarded_https = False # type: bool
def requestReceived(self, command, path, version):
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
# headers.
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
# for now, we just use the first x-forwarded-for header. Really, we ought
# to start from the client IP address, and check whether it is trusted; if it
# is, work backwards through the headers until we find an untrusted address.
# see https://github.com/matrix-org/synapse/issues/9471
self._forwarded_for = _XForwardedForAddress(
headers[0].split(b",")[0].strip().decode("ascii")
)
# if we got an x-forwarded-for header, also look for an x-forwarded-proto header
header = self.getHeader(b"x-forwarded-proto")
if header is not None:
self._forwarded_https = header.lower() == b"https"
else:
# this is done largely for backwards-compatibility so that people that
# haven't set an x-forwarded-proto header don't get a redirect loop.
logger.warning(
"forwarded request lacks an x-forwarded-proto header: assuming https"
)
self._forwarded_https = True
def isSecure(self):
if self._forwarded_https:
return True
return super().isSecure()
def getClientIP(self) -> str:
"""
Return the IP address of the client who submitted this request.
This method is deprecated. Use getClientAddress() instead.
"""
if self._forwarded_for is not None:
return self._forwarded_for.host
return super().getClientIP()
def getClientAddress(self) -> IAddress:
"""
Return the address of the client who submitted this request.
"""
if self._forwarded_for is not None:
return self._forwarded_for
return super().getClientAddress()
@implementer(IAddress)
@attr.s(frozen=True, slots=True)
class _XForwardedForAddress:
host = attr.ib(type=str)
class SynapseSite(Site):
"""
@ -377,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None
proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.requestFactory = (
XForwardedForRequest if proxied else SynapseRequest
) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")