Run Black. (#5482)

This commit is contained in:
Amber Brown 2019-06-20 19:32:02 +10:00 committed by GitHub
parent 7dcf984075
commit 32e7c9e7f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
376 changed files with 9142 additions and 10388 deletions

View file

@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
def __init__(self):
super(RequestTimedOutError, self).__init__(504, "Timed out")
@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout):
return value
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
def redact_uri(uri):
"""Strips access tokens from the uri replaces with <redacted>"""
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
uri
)
return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer):
@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
Workaround for https://github.com/matrix-org/synapse/issues/4003 /
https://twistedmatrix.com/trac/ticket/6528
"""
def stopProducing(self):
try:
FileBodyProducer.stopProducing(self)

View file

@ -28,6 +28,7 @@ class AdditionalResource(Resource):
This class is also where we wrap the request handler with logging, metrics,
and exception handling.
"""
def __init__(self, hs, handler):
"""Initialise AdditionalResource

View file

@ -103,8 +103,8 @@ class IPBlacklistingResolver(object):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
"Dropped %s from DNS resolution to %s due to blacklist" %
(ip_address, hostname)
"Dropped %s from DNS resolution to %s due to blacklist"
% (ip_address, hostname)
)
has_bad_ip = True
@ -156,7 +156,7 @@ class BlacklistingAgentWrapper(Agent):
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
h = urllib.parse.urlparse(uri.decode('ascii'))
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
ip_address = IPAddress(h.hostname)
@ -164,10 +164,7 @@ class BlacklistingAgentWrapper(Agent):
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
"Blocking access to %s due to blacklist" %
(ip_address,)
)
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
@ -206,7 +203,7 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
self.user_agent = self.user_agent.encode('ascii')
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
real_reactor = hs.get_reactor()
@ -520,8 +517,8 @@ class SimpleHttpClient(object):
resp_headers = dict(response.headers.getAllRawHeaders())
if (
b'Content-Length' in resp_headers
and int(resp_headers[b'Content-Length'][0]) > max_size
b"Content-Length" in resp_headers
and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
@ -546,18 +543,13 @@ class SimpleHttpClient(object):
# This can happen e.g. because the body is too large.
raise
except Exception as e:
raise_from(
SynapseError(
502, ("Failed to download remote body: %s" % e),
),
e
)
raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
defer.returnValue(
(
length,
resp_headers,
response.request.absoluteURI.decode('ascii'),
response.request.absoluteURI.decode("ascii"),
response.code,
)
)
@ -647,7 +639,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
if isinstance(arg, text_type):
return arg.encode('utf-8')
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:

View file

@ -31,7 +31,7 @@ def parse_server_name(server_name):
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == ']':
if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@ -43,9 +43,7 @@ def parse_server_name(server_name):
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile(
"\\A[0-9a-zA-Z.-]+\\Z",
)
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name):
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == '[':
if host[-1] != ']':
raise ValueError("Mismatched [...] in server name '%s'" % (
server_name,
))
if host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError("Server name '%s' contains invalid characters" % (
server_name,
))
raise ValueError(
"Server name '%s' contains invalid characters" % (server_name,)
)
return host, port

View file

@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
well_known_cache = TTLCache('well-known')
well_known_cache = TTLCache("well-known")
@implementer(IAgent)
@ -78,7 +78,9 @@ class MatrixFederationAgent(object):
"""
def __init__(
self, reactor, tls_client_options_factory,
self,
reactor,
tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
@ -100,9 +102,9 @@ class MatrixFederationAgent(object):
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
agent_args['contextFactory'] = _well_known_tls_policy
agent_args["contextFactory"] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
Agent(self._reactor, pool=self._pool, **agent_args),
Agent(self._reactor, pool=self._pool, **agent_args)
)
self._well_known_agent = _well_known_agent
@ -149,7 +151,7 @@ class MatrixFederationAgent(object):
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
res.tls_server_name.decode("ascii"),
res.tls_server_name.decode("ascii")
)
# make sure that the Host header is set correctly
@ -158,14 +160,14 @@ class MatrixFederationAgent(object):
else:
headers = headers.copy()
if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', res.host_header)
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
ep = LoggingHostnameEndpoint(
self._reactor, res.target_host, res.target_port,
self._reactor, res.target_host, res.target_port
)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
@ -203,21 +205,25 @@ class MatrixFederationAgent(object):
port = parsed_uri.port
if port == -1:
port = 8448
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=port,
))
defer.returnValue(
_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=port,
)
)
if parsed_uri.port != -1:
# there is an explicit port
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
))
defer.returnValue(
_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
)
)
if lookup_well_known:
# try a .well-known lookup
@ -229,8 +235,8 @@ class MatrixFederationAgent(object):
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
if b':' in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
if b":" in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
try:
well_known_port = int(well_known_port)
except ValueError:
@ -264,21 +270,27 @@ class MatrixFederationAgent(object):
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
parsed_uri.host.decode("ascii"),
target_host.decode("ascii"),
port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
target_host.decode("ascii"),
port,
parsed_uri.host.decode("ascii"),
)
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
))
defer.returnValue(
_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
)
)
@defer.inlineCallbacks
def _get_well_known(self, server_name):
@ -318,18 +330,18 @@ class MatrixFederationAgent(object):
- None if there was no .well-known file.
- INVALID_WELL_KNOWN if the .well-known was invalid
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name, )
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri),
self._well_known_agent.request(b"GET", uri)
)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code, ))
raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json.loads(body.decode('utf-8'))
parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
@ -347,8 +359,7 @@ class MatrixFederationAgent(object):
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers(
response.headers,
time_now=self._reactor.seconds,
response.headers, time_now=self._reactor.seconds
)
if cache_period is None:
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
@ -364,6 +375,7 @@ class MatrixFederationAgent(object):
@implementer(IStreamClientEndpoint)
class LoggingHostnameEndpoint(object):
"""A wrapper for HostnameEndpint which logs when it connects"""
def __init__(self, reactor, host, port, *args, **kwargs):
self.host = host
self.port = port
@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object):
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
if b'no-store' in cache_controls:
if b"no-store" in cache_controls:
return 0
if b'max-age' in cache_controls:
if b"max-age" in cache_controls:
try:
max_age = int(cache_controls[b'max-age'])
max_age = int(cache_controls[b"max-age"])
return max_age
except ValueError:
pass
expires = headers.getRawHeaders(b'expires')
expires = headers.getRawHeaders(b"expires")
if expires is not None:
try:
expires_date = stringToDatetime(expires[-1])
@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time):
def _parse_cache_control(headers):
cache_controls = {}
for hdr in headers.getRawHeaders(b'cache-control', []):
for directive in hdr.split(b','):
splits = [x.strip() for x in directive.split(b'=', 1)]
for hdr in headers.getRawHeaders(b"cache-control", []):
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v

View file

@ -45,6 +45,7 @@ class Server(object):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
@ -79,9 +80,7 @@ def pick_server_from_list(server_list):
return s.host, s.port
# this should be impossible.
raise RuntimeError(
"pick_server_from_list got to end of eligible server list.",
)
raise RuntimeError("pick_server_from_list got to end of eligible server list.")
class SrvResolver(object):
@ -95,6 +94,7 @@ class SrvResolver(object):
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
@ -124,7 +124,7 @@ class SrvResolver(object):
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
self._dns_client.lookupService(service_name)
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
@ -136,17 +136,18 @@ class SrvResolver(object):
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
"Failed to resolve %r, falling back to cache. %r", service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
if (
len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
@ -157,13 +158,15 @@ class SrvResolver(object):
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
))
servers.append(
Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
)
)
self._cache[service_name] = list(servers)
defer.returnValue(servers)

View file

@ -54,10 +54,12 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
"", ["method"])
incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
"", ["method", "code"])
outgoing_requests_counter = Counter(
"synapse_http_matrixfederationclient_requests", "", ["method"]
)
incoming_responses_counter = Counter(
"synapse_http_matrixfederationclient_responses", "", ["method", "code"]
)
MAX_LONG_RETRIES = 10
@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
check_content_type_is_json(response.headers)
d = treq.json_content(response)
d = timeout_deferred(
d,
timeout=timeout_sec,
reactor=reactor,
)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
except Exception as e:
@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
request.txn_id,
request.destination,
response.code,
response.phrase.decode('ascii', errors='replace'),
response.phrase.decode("ascii", errors="replace"),
)
defer.returnValue(body)
@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object):
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
real_reactor, None, hs.config.federation_ip_range_blacklist,
real_reactor, None, hs.config.federation_ip_range_blacklist
)
@implementer(IReactorPluggableNameResolver)
@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object):
self.reactor = Reactor()
self.agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
)
self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent, self.reactor,
self.agent,
self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string_bytes = hs.version_string.encode('ascii')
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
def schedule(x):
@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def _send_request_with_optional_trailing_slash(
self,
request,
try_trailing_slash_on_400=False,
**send_request_args
self, request, try_trailing_slash_on_400=False, **send_request_args
):
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object):
Deferred[Dict]: Parsed JSON response body.
"""
try:
response = yield self._send_request(
request, **send_request_args
)
response = yield self._send_request(request, **send_request_args)
except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash
@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash")
request.path += "/"
response = yield self._send_request(
request, **send_request_args
)
response = yield self._send_request(request, **send_request_args)
defer.returnValue(response)
@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
if (
self.hs.config.federation_domain_whitelist is not None and
request.destination not in self.hs.config.federation_domain_whitelist
self.hs.config.federation_domain_whitelist is not None
and request.destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object):
else:
query_bytes = b""
headers_dict = {
b"User-Agent": [self.version_string_bytes],
}
headers_dict = {b"User-Agent": [self.version_string_bytes]}
with limiter:
# XXX: Would be much nicer to retry only at the transaction-layer
@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
url_bytes = urllib.parse.urlunparse((
b"matrix", destination_bytes,
path_bytes, None, query_bytes, b"",
))
url_str = url_bytes.decode('ascii')
url_bytes = urllib.parse.urlunparse(
(b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
)
url_str = url_bytes.decode("ascii")
url_to_sign_bytes = urllib.parse.urlunparse((
b"", b"",
path_bytes, None, query_bytes, b"",
))
url_to_sign_bytes = urllib.parse.urlunparse(
(b"", b"", path_bytes, None, query_bytes, b"")
)
while True:
try:
@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object):
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
json,
destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data),
cooperator=self._cooperator,
BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
destination_bytes, method_bytes, url_to_sign_bytes
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
url_str, _sec_timeout,
request.txn_id,
request.destination,
request.method,
url_str,
_sec_timeout,
)
try:
@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
response.phrase.decode('ascii', errors='replace'),
response.phrase.decode("ascii", errors="replace"),
)
if 200 <= response.code < 300:
@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object):
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
d,
timeout=_sec_timeout,
reactor=self.reactor,
d, timeout=_sec_timeout, reactor=self.reactor
)
try:
@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object):
)
body = None
e = HttpResponseException(
response.code, response.phrase, body
)
e = HttpResponseException(response.code, response.phrase, body)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None,
self, destination, method, url_bytes, content=None, destination_is=None
):
"""
Builds the Authorization headers for a federation request
@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
}
request = {"method": method, "uri": url_bytes, "origin": self.server_name}
if destination is not None:
request["destination"] = destination
@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append((
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)).encode('ascii')
auth_headers.append(
(
'X-Matrix origin=%s,key="%s",sig="%s"'
% (self.server_name, key, sig)
).encode("ascii")
)
return auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, args={}, data={},
json_data_callback=None,
long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False,
try_trailing_slash_on_400=False):
def put_json(
self,
destination,
path,
args={},
data={},
json_data_callback=None,
long_retries=False,
timeout=None,
ignore_backoff=False,
backoff_on_404=False,
try_trailing_slash_on_400=False,
):
""" Sends the specifed json data using PUT
Args:
@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.reactor, self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
timeout=None, ignore_backoff=False, args={}):
def post_json(
self,
destination,
path,
data={},
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
""" Sends the specifed json data using POST
Args:
@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object):
"""
request = MatrixFederationRequest(
method="POST",
destination=destination,
path=path,
query=args,
json=data,
method="POST", destination=destination, path=path, query=args, json=data
)
response = yield self._send_request(
@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
self.reactor, _sec_timeout, request, response,
self.reactor, _sec_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
timeout=None, ignore_backoff=False,
try_trailing_slash_on_400=False):
def get_json(
self,
destination,
path,
args=None,
retry_on_dns_fail=True,
timeout=None,
ignore_backoff=False,
try_trailing_slash_on_400=False,
):
""" GETs some json from the given host homeserver and path
Args:
@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="GET",
destination=destination,
path=path,
query=args,
method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request_with_optional_trailing_slash(
@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.reactor, self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False,
timeout=None, ignore_backoff=False, args={}):
def delete_json(
self,
destination,
path,
long_retries=False,
timeout=None,
ignore_backoff=False,
args={},
):
"""Send a DELETE request to the remote expecting some json response
Args:
@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="DELETE",
destination=destination,
path=path,
query=args,
method="DELETE", destination=destination, path=path, query=args
)
response = yield self._send_request(
@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.reactor, self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
def get_file(
self,
destination,
path,
output_stream,
args={},
retry_on_dns_fail=True,
max_size=None,
ignore_backoff=False,
):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="GET",
destination=destination,
path=path,
query=args,
method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
)
headers = dict(response.headers.getAllRawHeaders())
@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
response.phrase.decode('ascii', errors='replace'),
response.phrase.decode("ascii", errors="replace"),
length,
)
defer.returnValue((length, headers))
@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
))
self.deferred.errback(
SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
)
self.deferred = defer.Deferred()
self.transport.loseConnection()
@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
_flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
@ -943,16 +944,15 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RequestSendFailed(RuntimeError(
"No Content-Type header"
), can_retry=False)
raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
c_type = c_type[0].decode('ascii') # only the first header
c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
), can_retry=False)
raise RequestSendFailed(
RuntimeError("Content-Type not application/json: was '%s'" % c_type),
can_retry=False,
)
def encode_query_args(args):
@ -967,4 +967,4 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes.encode('utf8')
return query_bytes.encode("utf8")

View file

@ -81,9 +81,7 @@ def wrap_json_request_handler(h):
yield h(self, request)
except SynapseError as e:
code = e.code
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@ -96,7 +94,10 @@ def wrap_json_request_handler(h):
pass
else:
respond_with_json(
request, code, e.error_dict(), send_cors=True,
request,
code,
e.error_dict(),
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@ -124,10 +125,7 @@ def wrap_json_request_handler(h):
respond_with_json(
request,
500,
{
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@ -143,6 +141,7 @@ def wrap_html_request_handler(h):
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
@ -164,9 +163,7 @@ def _return_html_error(f, request):
msg = cme.msg
if isinstance(cme, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, msg
)
logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
"Failed handle request %r",
@ -183,9 +180,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
body = HTML_ERROR_TEMPLATE.format(
code=code, msg=cgi.escape(msg),
).encode("utf-8")
body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
@ -205,6 +200,7 @@ def wrap_async_request_handler(h):
The handler may return a deferred, in which case the completion of the request isn't
logged until the deferred completes.
"""
@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
with request.processing():
@ -306,12 +302,14 @@ class JsonResource(HttpServer, resource.Resource):
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
kwargs = intern_dict({
name: _unquote(value) if value else value
for name, value in group_dict.items()
})
kwargs = intern_dict(
{
name: _unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
@ -339,7 +337,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path.decode('ascii'))
m = path_entry.pattern.match(request.path.decode("ascii"))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
@ -347,11 +345,14 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object,
response_code_message=None):
def _send_response(
self, request, code, response_json_object, response_code_message=None
):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
request,
code,
response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
@ -395,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path
def render_GET(self, request):
return redirectTo(self.url.encode('ascii'), request)
return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request):
if len(name) == 0:
@ -403,16 +404,22 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
canonical_json=True):
def respond_with_json(
request,
code,
json_object,
send_cors=False,
response_code_message=None,
pretty_print=False,
canonical_json=True,
):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
"Not sending response to request %s, already disconnected.",
request)
"Not sending response to request %s, already disconnected.", request
)
return
if pretty_print:
@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False,
json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes(
request, code, json_bytes,
request,
code,
json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
response_code_message=None):
def respond_with_json_bytes(
request, code, json_bytes, send_cors=False, response_code_message=None
):
"""Sends encoded JSON in response to the given request.
Args:
@ -474,7 +484,7 @@ def set_cors_headers(request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
)
@ -498,9 +508,7 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[]
)
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True

View file

@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
name = name.encode("ascii")
if name in args:
try:
@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
name = name.encode("ascii")
if name in args:
try:
return {
b"true": True,
b"false": False,
}[args[name][0]]
return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
"Boolean query parameter %r must be one of" " ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return default
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string", encoding='ascii'):
def parse_string(
request,
name,
default=None,
required=False,
allowed_values=None,
param_type="string",
encoding="ascii",
):
"""
Parse a string parameter from the request query string.
@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False,
)
def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string", encoding='ascii'):
def parse_string_from_args(
args,
name,
default=None,
required=False,
allowed_values=None,
param_type="string",
encoding="ascii",
):
if not isinstance(name, bytes):
name = name.encode('ascii')
name = name.encode("ascii")
if name in args:
value = args[name][0]
@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False,
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
name,
", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try:
content_unicode = content_bytes.decode('utf8')
content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
content = parse_json_value_from_request(
request, allow_empty_body=allow_empty_body,
)
content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
if allow_empty_body and content is None:
return {}

View file

@ -46,10 +46,11 @@ class SynapseRequest(Request):
Attributes:
logcontext(LoggingContext) : the log context for this request
"""
def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = site
self._channel = channel # this is used by the tests
self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
@ -72,12 +73,12 @@ class SynapseRequest(Request):
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
id(self),
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode('ascii', errors='replace'),
self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
)
@ -87,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
uri = self.uri.decode('ascii')
uri = self.uri.decode("ascii")
return redact_uri(uri)
def get_method(self):
@ -102,7 +103,7 @@ class SynapseRequest(Request):
"""
method = self.method
if isinstance(method, bytes):
method = self.method.decode('ascii')
method = self.method.decode("ascii")
return method
def get_user_agent(self):
@ -134,8 +135,7 @@ class SynapseRequest(Request):
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
requests_counter.labels(self.get_method(),
self.request_metrics.name).inc()
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
@ -200,7 +200,7 @@ class SynapseRequest(Request):
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
"Error processing request %r: %s %s", self, reason.type, reason.value,
"Error processing request %r: %s %s", self, reason.type, reason.value
)
if not self._is_processing:
@ -222,7 +222,7 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
self.start_time, name=servlet_name, method=self.get_method(),
self.start_time, name=servlet_name, method=self.get_method()
)
self.site.access_logger.info(
@ -230,7 +230,7 @@ class SynapseRequest(Request):
self.getClientIP(),
self.site.site_tag,
self.get_method(),
self.get_redacted_uri()
self.get_redacted_uri(),
)
def _finished_processing(self):
@ -282,7 +282,7 @@ class SynapseRequest(Request):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
@ -297,7 +297,7 @@ class SynapseRequest(Request):
code,
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode('ascii', errors='replace'),
self.clientproto.decode("ascii", errors="replace"),
user_agent,
usage.evt_db_fetch_count,
)
@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest):
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
return (
self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
.split(b",")[0]
.strip()
.decode("ascii")
)
class SynapseRequestFactory(object):
@ -343,8 +348,17 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource,
server_version_string, *args, **kwargs):
def __init__(
self,
logger_name,
site_tag,
config,
resource,
server_version_string,
*args,
**kwargs
):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@ -352,7 +366,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode('ascii')
self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
pass