MatrixFederationAgent: factor out routing logic

This is going to get too big and unmanageable.
This commit is contained in:
Richard van der Hoff 2019-01-27 23:24:17 +00:00
parent d840019192
commit 51958df766

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
@ -85,9 +86,11 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request response from being received (including problems that prevent the request
from being sent). from being sent).
""" """
parsed_uri = URI.fromBytes(uri, defaultPort=-1) parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
# set up the TLS connection params
#
# XXX disabling TLS is really only supported here for the benefit of the # XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make # unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests. # the code support the unit tests.
@ -95,22 +98,9 @@ class MatrixFederationAgent(object):
tls_options = None tls_options = None
else: else:
tls_options = self._tls_client_options_factory.get_options( tls_options = self._tls_client_options_factory.get_options(
parsed_uri.host.decode("ascii") res.tls_server_name.decode("ascii")
) )
if parsed_uri.port != -1:
# there was an explicit port in the URI
target = parsed_uri.host, parsed_uri.port
else:
service_name = b"_matrix._tcp.%s" % (parsed_uri.host, )
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target = (parsed_uri.host, 8448)
logger.debug(
"No SRV record for %s, using %s", service_name, target)
else:
target = pick_server_from_list(server_list)
# make sure that the Host header is set correctly # make sure that the Host header is set correctly
if headers is None: if headers is None:
headers = Headers() headers = Headers()
@ -118,13 +108,13 @@ class MatrixFederationAgent(object):
headers = headers.copy() headers = headers.copy()
if not headers.hasHeader(b'host'): if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', parsed_uri.netloc) headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object): class EndpointFactory(object):
@staticmethod @staticmethod
def endpointForURI(_uri): def endpointForURI(_uri):
logger.info("Connecting to %s:%s", target[0], target[1]) logger.info("Connecting to %s:%s", res.target_host, res.target_port)
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None: if tls_options is not None:
ep = wrapClientTLS(tls_options, ep) ep = wrapClientTLS(tls_options, ep)
return ep return ep
@ -134,3 +124,57 @@ class MatrixFederationAgent(object):
agent.request(method, uri, headers, bodyProducer) agent.request(method, uri, headers, bodyProducer)
) )
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri):
"""Helper for `request`: determine the routing for a Matrix URI
Args:
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given.
Returns:
Deferred[_RoutingResult]
"""
if parsed_uri.port != -1:
# there is an explicit port
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
))
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target_host = parsed_uri.host
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
)
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
))
@attr.s
class _RoutingResult(object):
host_header = attr.ib()
tls_server_name = attr.ib()
target_host = attr.ib()
target_port = attr.ib()