mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-18 02:00:15 -04:00
Merge remote-tracking branch 'upstream/release-v1.29.0'
This commit is contained in:
commit
adb990d8ba
126 changed files with 2399 additions and 765 deletions
|
@ -14,8 +14,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from twisted.internet import task
|
||||
from twisted.internet import address, task
|
||||
from twisted.web.client import FileBodyProducer
|
||||
from twisted.web.iweb import IRequest
|
||||
|
||||
|
@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
|
|||
pass
|
||||
|
||||
|
||||
def get_request_uri(request: IRequest) -> bytes:
|
||||
"""Return the full URI that was requested by the client"""
|
||||
return b"%s://%s%s" % (
|
||||
b"https" if request.isSecure() else b"http",
|
||||
_get_requested_host(request),
|
||||
# despite its name, "request.uri" is only the path and query-string.
|
||||
request.uri,
|
||||
)
|
||||
|
||||
|
||||
def _get_requested_host(request: IRequest) -> bytes:
|
||||
hostname = request.getHeader(b"host")
|
||||
if hostname:
|
||||
return hostname
|
||||
|
||||
# no Host header, use the address/port that the request arrived on
|
||||
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
|
||||
|
||||
hostname = host.host.encode("ascii")
|
||||
|
||||
if request.isSecure() and host.port == 443:
|
||||
# default port for https
|
||||
return hostname
|
||||
|
||||
if not request.isSecure() and host.port == 80:
|
||||
# default port for http
|
||||
return hostname
|
||||
|
||||
return b"%s:%i" % (
|
||||
hostname,
|
||||
host.port,
|
||||
)
|
||||
|
||||
|
||||
def get_request_user_agent(request: IRequest, default: str = "") -> str:
|
||||
"""Return the last User-Agent header, or the given default."""
|
||||
# There could be raw utf-8 bytes in the User-Agent header.
|
||||
|
|
|
@ -289,9 +289,8 @@ class SimpleHttpClient:
|
|||
treq_args: Dict[str, Any] = {},
|
||||
ip_whitelist: Optional[IPSet] = None,
|
||||
ip_blacklist: Optional[IPSet] = None,
|
||||
http_proxy: Optional[bytes] = None,
|
||||
https_proxy: Optional[bytes] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
use_proxy: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -301,8 +300,8 @@ class SimpleHttpClient:
|
|||
we may not request.
|
||||
ip_whitelist: The whitelisted IP addresses, that we can
|
||||
request if it were otherwise caught in a blacklist.
|
||||
http_proxy: proxy server to use for http connections. host[:port]
|
||||
https_proxy: proxy server to use for https connections. host[:port]
|
||||
use_proxy: Whether proxy settings should be discovered and used
|
||||
from conventional environment variables.
|
||||
"""
|
||||
self.hs = hs
|
||||
|
||||
|
@ -346,8 +345,7 @@ class SimpleHttpClient:
|
|||
connectTimeout=15,
|
||||
contextFactory=self.hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
http_proxy=http_proxy,
|
||||
https_proxy=https_proxy,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
|
||||
if self._ip_blacklist:
|
||||
|
@ -751,7 +749,32 @@ class BodyExceededMaxSize(Exception):
|
|||
"""The maximum allowed size of the HTTP body was exceeded."""
|
||||
|
||||
|
||||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which immediately errors upon receiving data."""
|
||||
|
||||
def __init__(self, deferred: defer.Deferred):
|
||||
self.deferred = deferred
|
||||
|
||||
def _maybe_fail(self):
|
||||
"""
|
||||
Report a max size exceed error and disconnect the first time this is called.
|
||||
"""
|
||||
if not self.deferred.called:
|
||||
self.deferred.errback(BodyExceededMaxSize())
|
||||
# Close the connection (forcefully) since all the data will get
|
||||
# discarded anyway.
|
||||
self.transport.abortConnection()
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
self._maybe_fail()
|
||||
|
||||
def connectionLost(self, reason: Failure) -> None:
|
||||
self._maybe_fail()
|
||||
|
||||
|
||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||
|
||||
def __init__(
|
||||
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
||||
):
|
||||
|
@ -808,13 +831,15 @@ def read_body_with_max_size(
|
|||
Returns:
|
||||
A Deferred which resolves to the length of the read body.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
|
||||
# If the Content-Length header gives a size larger than the maximum allowed
|
||||
# size, do not bother downloading the body.
|
||||
if max_size is not None and response.length != UNKNOWN_LENGTH:
|
||||
if response.length > max_size:
|
||||
return defer.fail(BodyExceededMaxSize())
|
||||
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
|
||||
return d
|
||||
|
||||
d = defer.Deferred()
|
||||
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
|
||||
return d
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import List, Optional
|
||||
from typing import Any, Generator, List, Optional
|
||||
|
||||
from netaddr import AddrFormatError, IPAddress, IPSet
|
||||
from zope.interface import implementer
|
||||
|
@ -116,7 +116,7 @@ class MatrixFederationAgent:
|
|||
uri: bytes,
|
||||
headers: Optional[Headers] = None,
|
||||
bodyProducer: Optional[IBodyProducer] = None,
|
||||
) -> defer.Deferred:
|
||||
) -> Generator[defer.Deferred, Any, defer.Deferred]:
|
||||
"""
|
||||
Args:
|
||||
method: HTTP method: GET/POST/etc
|
||||
|
@ -177,17 +177,17 @@ class MatrixFederationAgent:
|
|||
# We need to make sure the host header is set to the netloc of the
|
||||
# server and that a user-agent is provided.
|
||||
if headers is None:
|
||||
headers = Headers()
|
||||
request_headers = Headers()
|
||||
else:
|
||||
headers = headers.copy()
|
||||
request_headers = headers.copy()
|
||||
|
||||
if not headers.hasHeader(b"host"):
|
||||
headers.addRawHeader(b"host", parsed_uri.netloc)
|
||||
if not headers.hasHeader(b"user-agent"):
|
||||
headers.addRawHeader(b"user-agent", self.user_agent)
|
||||
if not request_headers.hasHeader(b"host"):
|
||||
request_headers.addRawHeader(b"host", parsed_uri.netloc)
|
||||
if not request_headers.hasHeader(b"user-agent"):
|
||||
request_headers.addRawHeader(b"user-agent", self.user_agent)
|
||||
|
||||
res = yield make_deferred_yieldable(
|
||||
self._agent.request(method, uri, headers, bodyProducer)
|
||||
self._agent.request(method, uri, request_headers, bodyProducer)
|
||||
)
|
||||
|
||||
return res
|
||||
|
|
|
@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
|
|||
RequestSendFailed: if the Content-Type header is missing or isn't JSON
|
||||
|
||||
"""
|
||||
c_type = headers.getRawHeaders(b"Content-Type")
|
||||
if c_type is None:
|
||||
content_type_headers = headers.getRawHeaders(b"Content-Type")
|
||||
if content_type_headers is None:
|
||||
raise RequestSendFailed(
|
||||
RuntimeError("No Content-Type header received from remote server"),
|
||||
can_retry=False,
|
||||
)
|
||||
|
||||
c_type = c_type[0].decode("ascii") # only the first header
|
||||
c_type = content_type_headers[0].decode("ascii") # only the first header
|
||||
val, options = cgi.parse_header(c_type)
|
||||
if val != "application/json":
|
||||
raise RequestSendFailed(
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from urllib.request import getproxies_environment, proxy_bypass_environment
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
|
@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
|
|||
|
||||
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
|
||||
non-persistent pool instance will be created.
|
||||
|
||||
use_proxy (bool): Whether proxy settings should be discovered and used
|
||||
from conventional environment variables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
|
|||
connectTimeout=None,
|
||||
bindAddress=None,
|
||||
pool=None,
|
||||
http_proxy=None,
|
||||
https_proxy=None,
|
||||
use_proxy=False,
|
||||
):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
|
@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
|
|||
if bindAddress is not None:
|
||||
self._endpoint_kwargs["bindAddress"] = bindAddress
|
||||
|
||||
http_proxy = None
|
||||
https_proxy = None
|
||||
no_proxy = None
|
||||
if use_proxy:
|
||||
proxies = getproxies_environment()
|
||||
http_proxy = proxies["http"].encode() if "http" in proxies else None
|
||||
https_proxy = proxies["https"].encode() if "https" in proxies else None
|
||||
no_proxy = proxies["no"] if "no" in proxies else None
|
||||
|
||||
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
|
|||
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
self.no_proxy = no_proxy
|
||||
|
||||
self._policy_for_https = contextFactory
|
||||
self._reactor = reactor
|
||||
|
||||
|
@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
|
|||
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||
request_path = parsed_uri.originForm
|
||||
|
||||
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
|
||||
should_skip_proxy = False
|
||||
if self.no_proxy is not None:
|
||||
should_skip_proxy = proxy_bypass_environment(
|
||||
parsed_uri.host.decode(),
|
||||
proxies={"no": self.no_proxy},
|
||||
)
|
||||
|
||||
if (
|
||||
parsed_uri.scheme == b"http"
|
||||
and self.http_proxy_endpoint
|
||||
and not should_skip_proxy
|
||||
):
|
||||
# Cache *all* connections under the same key, since we are only
|
||||
# connecting to a single destination, the proxy:
|
||||
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
||||
endpoint = self.http_proxy_endpoint
|
||||
request_path = uri
|
||||
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
|
||||
elif (
|
||||
parsed_uri.scheme == b"https"
|
||||
and self.https_proxy_endpoint
|
||||
and not should_skip_proxy
|
||||
):
|
||||
endpoint = HTTPConnectProxyEndpoint(
|
||||
self.proxy_reactor,
|
||||
self.https_proxy_endpoint,
|
||||
|
|
|
@ -21,6 +21,7 @@ import logging
|
|||
import types
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from inspect import isawaitable
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
Any,
|
||||
|
@ -30,6 +31,7 @@ from typing import (
|
|||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
Union,
|
||||
|
@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
|||
"""Sends a JSON error response to clients."""
|
||||
|
||||
if f.check(SynapseError):
|
||||
error_code = f.value.code
|
||||
error_dict = f.value.error_dict()
|
||||
# mypy doesn't understand that f.check asserts the type.
|
||||
exc = f.value # type: SynapseError # type: ignore
|
||||
error_code = exc.code
|
||||
error_dict = exc.error_dict()
|
||||
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
|
||||
else:
|
||||
error_code = 500
|
||||
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
|
||||
|
@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
|||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
|
||||
# Only respond with an error response if we haven't already started writing,
|
||||
|
@ -128,7 +132,8 @@ def return_html_error(
|
|||
`{msg}` placeholders), or a jinja2 template
|
||||
"""
|
||||
if f.check(CodeMessageException):
|
||||
cme = f.value
|
||||
# mypy doesn't understand that f.check asserts the type.
|
||||
cme = f.value # type: CodeMessageException # type: ignore
|
||||
code = cme.code
|
||||
msg = cme.msg
|
||||
|
||||
|
@ -142,7 +147,7 @@ def return_html_error(
|
|||
logger.error(
|
||||
"Failed handle request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
else:
|
||||
code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
@ -151,7 +156,7 @@ def return_html_error(
|
|||
logger.error(
|
||||
"Failed handle request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
|
||||
if isinstance(error_template, str):
|
||||
|
@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
|||
raw_callback_return = method_handler(request)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
if isawaitable(raw_callback_return):
|
||||
callback_return = await raw_callback_return
|
||||
else:
|
||||
callback_return = raw_callback_return # type: ignore
|
||||
|
@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
|
|||
A tuple of the callback to use, the name of the servlet, and the
|
||||
key word arguments to pass to the callback
|
||||
"""
|
||||
# At this point the path must be bytes.
|
||||
request_path_bytes = request.path # type: bytes # type: ignore
|
||||
request_path = request_path_bytes.decode("ascii")
|
||||
# Treat HEAD requests as GET requests.
|
||||
request_path = request.path.decode("ascii")
|
||||
request_method = request.method
|
||||
if request_method == b"HEAD":
|
||||
request_method = b"GET"
|
||||
|
@ -551,7 +558,7 @@ class _ByteProducer:
|
|||
request: Request,
|
||||
iterator: Iterator[bytes],
|
||||
):
|
||||
self._request = request
|
||||
self._request = request # type: Optional[Request]
|
||||
self._iterator = iterator
|
||||
self._paused = False
|
||||
|
||||
|
@ -563,7 +570,7 @@ class _ByteProducer:
|
|||
"""
|
||||
Send a list of bytes as a chunk of a response.
|
||||
"""
|
||||
if not data:
|
||||
if not data or not self._request:
|
||||
return
|
||||
self._request.write(b"".join(data))
|
||||
|
||||
|
|
|
@ -14,8 +14,12 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.server import Request, Site
|
||||
|
||||
|
@ -53,7 +57,7 @@ class SynapseRequest(Request):
|
|||
|
||||
def __init__(self, channel, *args, **kw):
|
||||
Request.__init__(self, channel, *args, **kw)
|
||||
self.site = channel.site
|
||||
self.site = channel.site # type: SynapseSite
|
||||
self._channel = channel # this is used by the tests
|
||||
self.start_time = 0.0
|
||||
|
||||
|
@ -92,25 +96,34 @@ class SynapseRequest(Request):
|
|||
def get_request_id(self):
|
||||
return "%s-%i" % (self.get_method(), self.request_seq)
|
||||
|
||||
def get_redacted_uri(self):
|
||||
uri = self.uri
|
||||
def get_redacted_uri(self) -> str:
|
||||
"""Gets the redacted URI associated with the request (or placeholder if the URI
|
||||
has not yet been received).
|
||||
|
||||
Note: This is necessary as the placeholder value in twisted is str
|
||||
rather than bytes, so we need to sanitise `self.uri`.
|
||||
|
||||
Returns:
|
||||
The redacted URI as a string.
|
||||
"""
|
||||
uri = self.uri # type: Union[bytes, str]
|
||||
if isinstance(uri, bytes):
|
||||
uri = self.uri.decode("ascii", errors="replace")
|
||||
uri = uri.decode("ascii", errors="replace")
|
||||
return redact_uri(uri)
|
||||
|
||||
def get_method(self):
|
||||
"""Gets the method associated with the request (or placeholder if not
|
||||
method has yet been received).
|
||||
def get_method(self) -> str:
|
||||
"""Gets the method associated with the request (or placeholder if method
|
||||
has not yet been received).
|
||||
|
||||
Note: This is necessary as the placeholder value in twisted is str
|
||||
rather than bytes, so we need to sanitise `self.method`.
|
||||
|
||||
Returns:
|
||||
str
|
||||
The request method as a string.
|
||||
"""
|
||||
method = self.method
|
||||
method = self.method # type: Union[bytes, str]
|
||||
if isinstance(method, bytes):
|
||||
method = self.method.decode("ascii")
|
||||
return self.method.decode("ascii")
|
||||
return method
|
||||
|
||||
def render(self, resrc):
|
||||
|
@ -333,27 +346,78 @@ class SynapseRequest(Request):
|
|||
|
||||
|
||||
class XForwardedForRequest(SynapseRequest):
|
||||
def __init__(self, *args, **kw):
|
||||
SynapseRequest.__init__(self, *args, **kw)
|
||||
"""Request object which honours proxy headers
|
||||
|
||||
"""
|
||||
Add a layer on top of another request that only uses the value of an
|
||||
X-Forwarded-For header as the result of C{getClientIP}.
|
||||
Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
|
||||
information from request headers.
|
||||
"""
|
||||
|
||||
def getClientIP(self):
|
||||
"""
|
||||
@return: The client address (the first address) in the value of the
|
||||
I{X-Forwarded-For header}. If the header is not present, return
|
||||
C{b"-"}.
|
||||
"""
|
||||
return (
|
||||
self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
|
||||
.split(b",")[0]
|
||||
.strip()
|
||||
.decode("ascii")
|
||||
# the client IP and ssl flag, as extracted from the headers.
|
||||
_forwarded_for = None # type: Optional[_XForwardedForAddress]
|
||||
_forwarded_https = False # type: bool
|
||||
|
||||
def requestReceived(self, command, path, version):
|
||||
# this method is called by the Channel once the full request has been
|
||||
# received, to dispatch the request to a resource.
|
||||
# We can use it to set the IP address and protocol according to the
|
||||
# headers.
|
||||
self._process_forwarded_headers()
|
||||
return super().requestReceived(command, path, version)
|
||||
|
||||
def _process_forwarded_headers(self):
|
||||
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
|
||||
if not headers:
|
||||
return
|
||||
|
||||
# for now, we just use the first x-forwarded-for header. Really, we ought
|
||||
# to start from the client IP address, and check whether it is trusted; if it
|
||||
# is, work backwards through the headers until we find an untrusted address.
|
||||
# see https://github.com/matrix-org/synapse/issues/9471
|
||||
self._forwarded_for = _XForwardedForAddress(
|
||||
headers[0].split(b",")[0].strip().decode("ascii")
|
||||
)
|
||||
|
||||
# if we got an x-forwarded-for header, also look for an x-forwarded-proto header
|
||||
header = self.getHeader(b"x-forwarded-proto")
|
||||
if header is not None:
|
||||
self._forwarded_https = header.lower() == b"https"
|
||||
else:
|
||||
# this is done largely for backwards-compatibility so that people that
|
||||
# haven't set an x-forwarded-proto header don't get a redirect loop.
|
||||
logger.warning(
|
||||
"forwarded request lacks an x-forwarded-proto header: assuming https"
|
||||
)
|
||||
self._forwarded_https = True
|
||||
|
||||
def isSecure(self):
|
||||
if self._forwarded_https:
|
||||
return True
|
||||
return super().isSecure()
|
||||
|
||||
def getClientIP(self) -> str:
|
||||
"""
|
||||
Return the IP address of the client who submitted this request.
|
||||
|
||||
This method is deprecated. Use getClientAddress() instead.
|
||||
"""
|
||||
if self._forwarded_for is not None:
|
||||
return self._forwarded_for.host
|
||||
return super().getClientIP()
|
||||
|
||||
def getClientAddress(self) -> IAddress:
|
||||
"""
|
||||
Return the address of the client who submitted this request.
|
||||
"""
|
||||
if self._forwarded_for is not None:
|
||||
return self._forwarded_for
|
||||
return super().getClientAddress()
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class _XForwardedForAddress:
|
||||
host = attr.ib(type=str)
|
||||
|
||||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
|
@ -377,7 +441,9 @@ class SynapseSite(Site):
|
|||
|
||||
assert config.http_options is not None
|
||||
proxied = config.http_options.x_forwarded
|
||||
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
|
||||
self.requestFactory = (
|
||||
XForwardedForRequest if proxied else SynapseRequest
|
||||
) # type: Type[Request]
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
self.server_version_string = server_version_string.encode("ascii")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue