diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 71a15f434..c20818579 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,21 +14,21 @@ # limitations under the License. import logging +import urllib -import attr -from netaddr import IPAddress +from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IAgentEndpointFactory -from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.util import Clock logger = logging.getLogger(__name__) @@ -36,8 +36,9 @@ logger = logging.getLogger(__name__) @implementer(IAgent) class MatrixFederationAgent(object): - """An Agent-like thing which provides a `request` method which will look up a matrix - server and send an HTTP request to it. + """An Agent-like thing which provides a `request` method which correctly + handles resolving matrix server names when using matrix://. Handles standard + https URIs as normal. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -65,23 +66,25 @@ class MatrixFederationAgent(object): ): self._reactor = reactor self._clock = Clock(reactor) - - self._tls_client_options_factory = tls_client_options_factory - if _srv_resolver is None: - _srv_resolver = SrvResolver() - self._srv_resolver = _srv_resolver - self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 + self._agent = Agent.usingEndpointFactory( + self._reactor, + MatrixHostnameEndpointFactory( + reactor, tls_client_options_factory, _srv_resolver + ), + pool=self._pool, + ) + self._well_known_resolver = WellKnownResolver( self._reactor, agent=Agent( self._reactor, - pool=self._pool, contextFactory=tls_client_options_factory, + pool=self._pool, ), well_known_cache=_well_known_cache, ) @@ -91,19 +94,15 @@ class MatrixFederationAgent(object): """ Args: method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): HTTP headers to send with the request, or None to send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have no body. - Returns: Deferred[twisted.web.iweb.IResponse]: fires when the header of the response has been received (regardless of the @@ -111,210 +110,195 @@ class MatrixFederationAgent(object): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + # We use urlparse as that will set `port` to None if there is no + # explicit port. + parsed_uri = urllib.parse.urlparse(uri) - # set up the TLS connection params + # If this is a matrix:// URI check if the server has delegated matrix + # traffic using well-known delegation. # - # XXX disabling TLS is really only supported here for the benefit of the - # unit tests. We should make the UTs cope with TLS rather than having to make - # the code support the unit tests. - if self._tls_client_options_factory is None: - tls_options = None - else: - tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + # We have to do this here and not in the endpoint as we need to rewrite + # the host header with the delegated server name. + delegated_server = None + if ( + parsed_uri.scheme == b"matrix" + and not _is_ip_literal(parsed_uri.hostname) + and not parsed_uri.port + ): + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.hostname ) + delegated_server = well_known_result.delegated_server - # make sure that the Host header is set correctly + if delegated_server: + # Ok, the server has delegated matrix traffic to somewhere else, so + # lets rewrite the URL to replace the server with the delegated + # server name. + uri = urllib.parse.urlunparse( + ( + parsed_uri.scheme, + delegated_server, + parsed_uri.path, + parsed_uri.params, + parsed_uri.query, + parsed_uri.fragment, + ) + ) + parsed_uri = urllib.parse.urlparse(uri) + + # We need to make sure the host header is set to the netloc of the + # server. if headers is None: headers = Headers() else: headers = headers.copy() if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) + headers.addRawHeader(b"host", parsed_uri.netloc) - class EndpointFactory(object): - @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port - ) - if tls_options is not None: - ep = wrapClientTLS(tls_options, ep) - return ep + with PreserveLoggingContext(): + res = yield self._agent.request(method, uri, headers, bodyProducer) - agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) - res = yield make_deferred_yieldable( - agent.request(method, uri, headers, bodyProducer) - ) return res - @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): - """Helper for `request`: determine the routing for a Matrix URI - Args: - parsed_uri (twisted.web.client.URI): uri to route. Note that it should be - parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 - if there is no explicit port given. +@implementer(IAgentEndpointFactory) +class MatrixHostnameEndpointFactory(object): + """Factory for MatrixHostnameEndpoint for parsing to an Agent. + """ - lookup_well_known (bool): True if we should look up the .well-known file if - there is no SRV record. + def __init__(self, reactor, tls_client_options_factory, srv_resolver): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory - Returns: - Deferred[_RoutingResult] - """ - # check for an IP literal - try: - ip_address = IPAddress(parsed_uri.host.decode("ascii")) - except Exception: - # not an IP address - ip_address = None + if srv_resolver is None: + srv_resolver = SrvResolver() - if ip_address: - port = parsed_uri.port - if port == -1: - port = 8448 - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) + self._srv_resolver = srv_resolver - if parsed_uri.port != -1: - # there is an explicit port - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) - - if lookup_well_known: - # try a .well-known lookup - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.host - ) - well_known_server = well_known_result.delegated_server - - if well_known_server: - # if we found a .well-known, start again, but don't do another - # .well-known lookup. - - # parse the server name in the .well-known response into host/port. - # (This code is lifted from twisted.web.client.URI.fromBytes). - if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) - try: - well_known_port = int(well_known_port) - except ValueError: - # the part after the colon could not be parsed as an int - # - we assume it is an IPv6 literal with no port (the closing - # ']' stops it being parsed as an int) - well_known_host, well_known_port = well_known_server, -1 - else: - well_known_host, well_known_port = well_known_server, -1 - - new_uri = URI( - scheme=parsed_uri.scheme, - netloc=well_known_server, - host=well_known_host, - port=well_known_port, - path=parsed_uri.path, - params=parsed_uri.params, - query=parsed_uri.query, - fragment=parsed_uri.fragment, - ) - - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - return res - - # try a SRV lookup - service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield self._srv_resolver.resolve_service(service_name) - - if not server_list: - target_host = parsed_uri.host - port = 8448 - logger.debug( - "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), - target_host.decode("ascii"), - port, - ) - else: - target_host, port = pick_server_from_list(server_list) - logger.debug( - "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), - port, - parsed_uri.host.decode("ascii"), - ) - - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, + def endpointForURI(self, parsed_uri): + return MatrixHostnameEndpoint( + self._reactor, + self._tls_client_options_factory, + self._srv_resolver, + parsed_uri, ) @implementer(IStreamClientEndpoint) -class LoggingHostnameEndpoint(object): - """A wrapper for HostnameEndpint which logs when it connects""" +class MatrixHostnameEndpoint(object): + """An endpoint that resolves matrix:// URLs using Matrix server name + resolution (i.e. via SRV). Does not check for well-known delegation. + """ - def __init__(self, reactor, host, port, *args, **kwargs): - self.host = host - self.port = port - self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) + def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + self._reactor = reactor + # We reparse the URI so that defaultPort is -1 rather than 80 + self._parsed_uri = parsed_uri + + # set up the TLS connection params + # + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + + if tls_client_options_factory is None: + self._tls_options = None + else: + self._tls_options = tls_client_options_factory.get_options( + self._parsed_uri.host.decode("ascii") + ) + + self._srv_resolver = srv_resolver + + @defer.inlineCallbacks def connect(self, protocol_factory): - logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) - return self.ep.connect(protocol_factory) + """Implements IStreamClientEndpoint interface + """ + + first_exception = None + + server_list = yield self._resolve_server() + + for server in server_list: + host = server.host + port = server.port + + try: + logger.info("Connecting to %s:%i", host.decode("ascii"), port) + endpoint = HostnameEndpoint(self._reactor, host, port) + if self._tls_options: + endpoint = wrapClientTLS(self._tls_options, endpoint) + result = yield make_deferred_yieldable( + endpoint.connect(protocol_factory) + ) + + return result + except Exception as e: + logger.info( + "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e + ) + if not first_exception: + first_exception = e + + # We return the first failure because that's probably the most interesting. + if first_exception: + raise first_exception + + # This shouldn't happen as we should always have at least one host/port + # to try and if that doesn't work then we'll have an exception. + raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) + + @defer.inlineCallbacks + def _resolve_server(self): + """Resolves the server name to a list of hosts and ports to attempt to + connect to. + + Returns: + Deferred[list[Server]] + """ + + if self._parsed_uri.scheme != b"matrix": + return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] + + # Note: We don't do well-known lookup as that needs to have happened + # before now, due to needing to rewrite the Host header of the HTTP + # request. + + parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) + + host = parsed_uri.hostname + port = parsed_uri.port + + # If there is an explicit port or the host is an IP address we bypass + # SRV lookups and just use the given host/port. + if port or _is_ip_literal(host): + return [Server(host, port or 8448)] + + server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + + if server_list: + return server_list + + # No SRV records, so we fallback to host and 8448 + return [Server(host, 8448)] -@attr.s -class _RoutingResult(object): - """The result returned by `_route_matrix_uri`. +def _is_ip_literal(host): + """Test if the given host name is either an IPv4 or IPv6 literal. - Contains the parameters needed to direct a federation connection to a particular - server. + Args: + host (bytes) - Where a SRV record points to several servers, this object contains a single server - chosen from the list. + Returns: + bool """ - host_header = attr.ib() - """ - The value we should assign to the Host header (host:port from the matrix - URI, or .well-known). + host = host.decode("ascii") - :type: bytes - """ - - tls_server_name = attr.ib() - """ - The server name we should set in the SNI (typically host, without port, from the - matrix URI or .well-known) - - :type: bytes - """ - - target_host = attr.ib() - """ - The hostname (or IP literal) we should route the TCP connection to (the target of the - SRV record, or the hostname from the URL/.well-known) - - :type: bytes - """ - - target_port = attr.ib() - """ - The port we should route the TCP connection to (the target of the SRV record, or - the port from the URL/.well-known, or 8448) - - :type: int - """ + try: + IPAddress(host) + return True + except AddrFormatError: + return False diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b32188766..bbda0a23f 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} -@attr.s +@attr.s(slots=True, frozen=True) class Server(object): """ Our record of an individual server which can be tried to reach a destination. @@ -83,6 +83,35 @@ def pick_server_from_list(server_list): raise RuntimeError("pick_server_from_list got to end of eligible server list.") +def _sort_server_list(server_list): + """Given a list of SRV records sort them into priority order and shuffle + each priority with the given weight. + """ + priority_map = {} + + for server in server_list: + priority_map.setdefault(server.priority, []).append(server) + + results = [] + for priority in sorted(priority_map): + servers = priority_map.pop(priority) + + while servers: + total_weight = sum(s.weight for s in servers) + target_weight = random.randint(0, total_weight) + + for s in servers: + target_weight -= s.weight + + if target_weight <= 0: + break + + results.append(s) + servers.remove(s) + + return results + + class SrvResolver(object): """Interface to the dns client to do SRV lookups, with result caching. @@ -120,7 +149,7 @@ class SrvResolver(object): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - return servers + return _sort_server_list(servers) try: answers, _, _ = yield make_deferred_yieldable( @@ -169,4 +198,4 @@ class SrvResolver(object): ) self._cache[service_name] = list(servers) - return servers + return _sort_server_list(servers) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 2c568788b..f97c8a59f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import ( from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache +from tests import unittest from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) @@ -67,7 +67,8 @@ def get_connection_factory(): return test_server_connection_factory -class MatrixFederationAgentTests(TestCase): +@unittest.DEBUG +class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -1056,8 +1057,64 @@ class MatrixFederationAgentTests(TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_srv_fallbacks(self): + """Test that other SRV results are tried if the first one fails. + """ -class TestCachePeriodFromHeaders(TestCase): + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + +class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self): # uppercase self.assertEqual( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 3b885ef64..df034ab23 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 0 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 999999999 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(