Merge remote-tracking branch 'upstream/release-v1.24.0'

This commit is contained in:
Tulir Asokan 2020-12-02 16:30:50 +02:00
commit 6e2f942da1
152 changed files with 4450 additions and 2462 deletions

View file

@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import urllib.parse
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
"""
Compares an IP address to allowed and disallowed IP sets.
Args:
ip_address (netaddr.IPAddress)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
ip_address: The IP address to check
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
Returns:
True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
def __init__(self, reactor, ip_whitelist, ip_blacklist):
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
):
"""
Args:
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
reactor: The twisted reactor.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def resolveHostName(self, recv, hostname, portNumber=0):
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
r = recv()
addresses = []
addresses = [] # type: List[IAddress]
def _callback():
def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress):
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
def addressResolved(address):
def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
def resolutionComplete():
def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
def __init__(
self,
agent: IAgent,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
):
"""
Args:
agent (twisted.web.client.Agent): The Agent to wrap.
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
agent: The Agent to wrap.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@ -226,24 +256,24 @@ class SimpleHttpClient:
def __init__(
self,
hs,
treq_args={},
ip_whitelist=None,
ip_blacklist=None,
http_proxy=None,
https_proxy=None,
user_agent=None,
hs: "HomeServer",
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
user_agent: Optional[str] = None,
):
"""
Args:
hs (synapse.server.HomeServer)
treq_args (dict): Extra keyword arguments to be given to treq.request.
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
hs
treq_args: Extra keyword arguments to be given to treq.request.
ip_blacklist: The IP addresses that are blacklisted that
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy (bytes): proxy server to use for http connections. host[:port]
https_proxy (bytes): proxy server to use for https connections. host[:port]
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@ -307,7 +337,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@ -398,7 +427,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
args: Mapping[str, Union[str, List[str]]] = {},
args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@ -423,9 +452,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
"utf8"
)
query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@ -433,7 +460,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@ -480,7 +507,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@ -496,7 +523,10 @@ class SimpleHttpClient:
)
async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@ -517,7 +547,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@ -526,7 +556,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
args: QueryParams = {},
args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@ -547,9 +577,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@ -559,7 +589,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@ -575,7 +605,10 @@ class SimpleHttpClient:
)
async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@ -593,13 +626,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@ -642,7 +675,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@ -650,12 +683,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@ -669,7 +703,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@ -697,18 +731,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@ -722,7 +754,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@ -733,35 +765,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read an HTTP response body to a file-object, optionally enforcing a maximum file size.
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
max_size: The maximum file size to allow.
Returns:
A Deferred which resolves to the length of the read body.
"""
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Args:
args: The query arguments, a mapping of string to string or list of strings.
def encode_urlencode_arg(arg):
if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
Returns:
The query arguments encoded as bytes.
"""
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
query_str = urllib.parse.urlencode(encoded_args, True)
return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):

View file

@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import List
import urllib.parse
from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.internet.interfaces import (
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
)
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
reactor (IReactor): twisted reactor to use for underlying requests
reactor: twisted reactor to use for underlying requests
tls_client_options_factory (FederationPolicyForHTTPS|None):
tls_client_options_factory:
factory to use for fetching client TLS options, or None to disable TLS.
user_agent (bytes):
user_agent:
The user agent header to use for federation requests.
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
_srv_resolver:
SrvResolver implementation to use for looking up SRV records. None
to use a default implementation.
_well_known_resolver (WellKnownResolver|None):
_well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
reactor,
tls_client_options_factory,
user_agent,
_srv_resolver=None,
_well_known_resolver=None,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
method: HTTP method: GET/POST/etc
uri: Absolute URI to be retrieved
headers:
HTTP headers to send with the request, or None to send no extra headers.
bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
# There must be a valid hostname.
assert parsed_uri.hostname
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for passing to an Agent.
"""
def __init__(self, reactor, tls_client_options_factory, srv_resolver):
def __init__(
self,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: Optional[SrvResolver],
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
reactor (IReactor)
tls_client_options_factory (ClientTLSOptionsFactory|None):
reactor: twisted reactor to use for underlying requests
tls_client_options_factory:
factory to use for fetching client TLS options, or None to disable TLS.
srv_resolver (SrvResolver): The SRV resolver to use
parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
to connect to.
srv_resolver: The SRV resolver to use
parsed_uri: The parsed URI that we're wanting to connect to.
"""
def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
def __init__(
self,
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: SrvResolver,
parsed_uri: URI,
):
self._reactor = reactor
self._parsed_uri = parsed_uri
@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory):
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory):
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
def _is_ip_literal(host):
def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
host (bytes)
host: The host name to check
Returns:
bool
True if the hostname is an IP address literal.
"""
host = host.decode("ascii")
host_str = host.decode("ascii")
try:
IPAddress(host)
IPAddress(host_str)
return True
except AddrFormatError:
return False

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import time
@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
reactor,
agent,
user_agent,
well_known_cache=None,
had_well_known_cache=None,
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
well_known_cache: Optional[TTLCache] = None,
had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
) # type: Tuple[Optional[bytes], float]
) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:

View file

@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
import urllib
import urllib.parse
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.client import (
BlacklistingAgentWrapper,
IPBlacklistingResolver,
encode_query_args,
readBodyToFile,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]]
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
method = attr.ib()
method = attr.ib(type=str)
"""HTTP method
:type: str
"""
path = attr.ib()
path = attr.ib(type=str)
"""HTTP path
:type: str
"""
destination = attr.ib()
destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
:type: str"""
"""
json = attr.ib(default=None)
json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
:type: dict|None
"""
json_callback = attr.ib(default=None)
json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
:type: func|None
"""
query = attr.ib(default=None)
query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
:type: dict|None
"""
txn_id = attr.ib(default=None)
txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
:type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
def get_json(self):
def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
):
) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
dict: parsed JSON response
The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
self, request, try_trailing_slash_on_400=False, **send_request_args
):
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
**send_request_args
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
request (MatrixFederationRequest): details of request to be sent
try_trailing_slash_on_400 (bool): Whether on receiving a 400
request: details of request to be sent
try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
send_request_args (Dict): A dictionary of arguments to pass to
`_send_request()`.
send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
Dict: Parsed JSON response body.
The HTTP response object.
"""
try:
response = await self._send_request(request, **send_request_args)
@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
request,
retry_on_dns_fail=True,
timeout=None,
long_retries=False,
ignore_backoff=False,
backoff_on_404=False,
):
request: MatrixFederationRequest,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
long_retries: bool = False,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
) -> IResponse:
"""
Sends a request to the given server.
Args:
request (MatrixFederationRequest): details of request to be sent
request: details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers
retry_on_dns_fail: true if the request should be retried on DNS failures
timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
long_retries (bool): whether to use the long retry algorithm.
long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_404: Back off if we get a 404
Returns:
twisted.web.client.Response: resolves with the HTTP
response object on success.
Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
headers_dict = {}
headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
)
) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
e = HttpResponseException(response.code, response_phrase, body)
exc = HttpResponseException(
response.code, response_phrase, body
)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
raise RequestSendFailed(e, can_retry=True) from e
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise e
raise exc
break
except RequestSendFailed as e:
@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None
):
self,
destination: Optional[bytes],
method: bytes,
url_bytes: bytes,
content: Optional[JsonDict] = None,
destination_is: Optional[bytes] = None,
) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
destination (bytes|None): The destination homeserver of the request.
destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
url_bytes (bytes): The URI path of the request
content (object): The body of the request
destination_is (bytes): As 'destination', but if the destination is an
method: The HTTP method of the request
url_bytes: The URI path of the request
content: The body of the request
destination_is: As 'destination', but if the destination is an
identity server
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
destination,
path,
args={},
data={},
json_data_callback=None,
long_retries=False,
timeout=None,
ignore_backoff=False,
backoff_on_404=False,
try_trailing_slash_on_400=False,
):
destination: str,
path: str,
args: Optional[QueryArgs] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): query params
data (dict): A dict containing the data that will be used as
destination: The remote server to send the HTTP request to.
path: The HTTP path.
args: query params
data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
json_data_callback: A callable returning the dict to
use as the request body.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
destination,
path,
data={},
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
destination: str,
path: str,
data: Optional[JsonDict] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
Args:
destination (str): The remote server to send the HTTP request
to.
destination: The remote server to send the HTTP request to.
path (str): The HTTP path.
path: The HTTP path.
data (dict): A dict containing the data that will be used as
data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data and
ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
args (dict): query params
args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
destination,
path,
args=None,
retry_on_dns_fail=True,
timeout=None,
ignore_backoff=False,
try_trailing_slash_on_400=False,
):
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
destination: The remote server to send the HTTP request to.
path (str): The HTTP path.
path: The HTTP path.
args (dict|None): A dictionary used to create query strings, defaults to
args: A dictionary used to create query strings, defaults to
None.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
destination,
path,
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
destination: str,
path: str,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
destination: The remote server to send the HTTP request to.
path: The HTTP path.
long_retries (bool): whether to use the long retry algorithm. See
long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
timeout (int|None): number of milliseconds to wait for the response.
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff (bool): true to ignore the historical backoff data and
ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
args (dict): query params
args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
destination,
path,
destination: str,
path: str,
output_stream,
args={},
retry_on_dns_fail=True,
max_size=None,
ignore_backoff=False,
):
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
tuple[int, dict]: Resolves with an (int,dict) tuple of
Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
d = _readBodyToFile(response, output_stream, max_size)
d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
    """Protocol that copies a response body into a stream, with a size cap.

    Each chunk received is written to ``stream`` and added to ``length``.
    This protocol fires ``deferred`` at most once:

    * with ``length`` when the connection is lost with ``ResponseDone``;
    * with the connection-lost reason for any other reason;
    * with a 502 ``SynapseError`` (``Codes.TOO_LARGE``) as soon as ``length``
      reaches ``max_size``.

    Args:
        stream: object whose ``write`` method receives each chunk of the body.
        deferred: Deferred that reports the outcome of the transfer.
        max_size: number of bytes at which the transfer is aborted, or None
            for no limit.
    """

    def __init__(self, stream, deferred, max_size):
        self.stream = stream
        self.deferred = deferred
        self.length = 0
        self.max_size = max_size

    def dataReceived(self, data):
        # loseConnection() does not stop delivery immediately, so more chunks
        # can arrive after the limit was hit (or after the deferred was fired
        # from outside, e.g. by a timeout). Drop them rather than keep writing
        # to the stream and firing the deferred again.
        if self.deferred.called:
            return

        self.stream.write(data)
        self.length += len(data)

        # NOTE: the comparison is inclusive, so a body of exactly max_size
        # bytes is rejected as well.
        if self.max_size is not None and self.length >= self.max_size:
            self.deferred.errback(
                SynapseError(
                    502,
                    "Requested file is too large > %r bytes" % (self.max_size,),
                    Codes.TOO_LARGE,
                )
            )
            self.transport.loseConnection()

    def connectionLost(self, reason):
        # The outcome has already been reported (size limit or external
        # firing); firing again would either raise or go unobserved.
        if self.deferred.called:
            return

        if reason.check(ResponseDone):
            self.deferred.callback(self.length)
        else:
            self.deferred.errback(reason)
def _readBodyToFile(response, stream, max_size):
    """Start delivering the body of ``response`` into ``stream``.

    Args:
        response: object whose ``deliverBody`` method accepts a protocol.
        stream: passed to ``_ReadBodyToFileProtocol`` as the output stream.
        max_size: passed to ``_ReadBodyToFileProtocol`` as the size limit.

    Returns:
        A new Deferred, which is handed to the protocol; the protocol is
        responsible for firing it.
    """
    result = defer.Deferred()
    body_receiver = _ReadBodyToFileProtocol(stream, result, max_size)
    response.deliverBody(body_receiver)
    return result
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
def check_content_type_is_json(headers):
def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
headers (twisted.web.http_headers.Headers): headers to check
headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers):
),
can_retry=False,
)
def encode_query_args(args):
    """Turn a mapping of query parameters into URL-encoded bytes.

    Args:
        args: mapping from parameter name to either a single string or an
            iterable of strings, or None.

    Returns:
        The URL-encoded query string as bytes; ``b""`` when ``args`` is None.
    """
    if args is None:
        return b""

    # A lone string is treated as a one-element list so that every parameter
    # ends up as a list of UTF-8 encoded values.
    utf8_params = {
        name: [
            item.encode("UTF-8")
            for item in ([values] if isinstance(values, str) else values)
        ]
        for name, values in args.items()
    }

    # doseq=True emits one "name=value" pair per list element.
    return urllib.parse.urlencode(utf8_params, True).encode("utf8")

View file

@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
request, error_code, error_dict, send_cors=True,
)
@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
pretty_print: Whether to include indentation and line-breaks in the
resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@ -615,13 +607,10 @@ def respond_with_json(
)
return None
if pretty_print:
encoder = iterencode_pretty_printed_json
if canonical_json:
encoder = iterencode_canonical_json
else:
if canonical_json:
encoder = iterencode_canonical_json
else:
encoder = _encode_json_bytes
encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _request_user_agent_is_curl(request: Request) -> bool:
    """Report whether any User-Agent header of ``request`` mentions curl.

    Args:
        request: the request whose ``requestHeaders`` are inspected.

    Returns:
        True if at least one raw ``User-Agent`` header value contains the
        bytes ``curl``; False otherwise, including when the header is absent.
    """
    agent_values = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
    return any(b"curl" in value for value in agent_values)