put resolve_service in an object

this makes it easier to stub things out for tests.
This commit is contained in:
Richard van der Hoff 2019-01-22 17:42:26 +00:00
parent 53a327b4d5
commit 7021784d46
3 changed files with 96 additions and 75 deletions

View File

@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
Args: Args:
reactor (IReactor): twisted reactor to use for underlying requests reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None): tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS. factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
""" """
def __init__(self, reactor, tls_client_options_factory): def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None,
):
self._reactor = reactor self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor) self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False self._pool.retryAutomatically = False
@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
if port is not None: if port is not None:
target = (host, port) target = (host, port)
else: else:
server_list = yield resolve_service(server_name_bytes) server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
if not server_list: if not server_list:
target = (host, 8448) target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target) logger.debug("No SRV record for %s, using %s", host, target)

View File

@ -84,73 +84,86 @@ def pick_server_from_list(server_list):
) )
@defer.inlineCallbacks class SrvResolver(object):
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): """Interface to the dns client to do SRV lookups, with result caching.
"""Look up a SRV record, with caching
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here. but the cache never gets populated), so we add our own caching layer here.
Args: Args:
service_name (bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object cache (dict): cache object
clock (object): clock implementation. must provide a time() method. get_time (callable): clock implementation. Should return seconds since the epoch
Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
""" """
if not isinstance(service_name, bytes): def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
raise TypeError("%r is not a byte string" % (service_name,)) self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
cache_entry = cache.get(service_name, None) @defer.inlineCallbacks
if cache_entry: def resolve_service(self, service_name):
if all(s.expires > int(clock.time()) for s in cache_entry): """Look up a SRV record
servers = list(cache_entry)
defer.returnValue(servers)
try: Args:
answers, _, _ = yield make_deferred_yieldable( service_name (bytes): record to look up
dns_client.lookupService(service_name),
) Returns:
except DNSNameError: Deferred[list[Server]]:
# TODO: cache this. We can get the SOA out of the exception, and use a list of the SRV records, or an empty list if none found
# the negative-TTL value. """
defer.returnValue([]) now = int(self._get_time())
except DomainError as e:
# We failed to resolve the name (other than a NameError) if not isinstance(service_name, bytes):
# Try something in the cache, else rereaise raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = cache.get(service_name, None)
cache_entry = self._cache.get(service_name, None)
if cache_entry: if cache_entry:
logger.warn( if all(s.expires > now for s in cache_entry):
"Failed to resolve %r, falling back to cache. %r", servers = list(cache_entry)
service_name, e defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
) )
defer.returnValue(list(cache_entry)) except DNSNameError:
else: # TODO: cache this. We can get the SOA out of the exception, and use
raise e # the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1 if (len(answers) == 1
and answers[0].type == dns.SRV and answers[0].type == dns.SRV
and answers[0].payload and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')): and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name) raise ConnectError("Service %s unavailable" % service_name)
servers = [] servers = []
for answer in answers: for answer in answers:
if answer.type != dns.SRV or not answer.payload: if answer.type != dns.SRV or not answer.payload:
continue continue
payload = answer.payload payload = answer.payload
servers.append(Server( servers.append(Server(
host=payload.target.name, host=payload.target.name,
port=payload.port, port=payload.port,
priority=payload.priority, priority=payload.priority,
weight=payload.weight, weight=payload.weight,
expires=int(clock.time()) + answer.ttl, expires=now + answer.ttl,
)) ))
cache[service_name] = list(servers) self._cache[service_name] = list(servers)
defer.returnValue(servers) defer.returnValue(servers)

View File

@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import dns, error from twisted.names import dns, error
from synapse.http.federation.srv_resolver import resolve_service from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest from tests import unittest
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = result_deferred dns_client_mock.lookupService.return_value = result_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_lookup(): def do_lookup():
with LoggingContext("one") as ctx: with LoggingContext("one") as ctx:
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
@ -89,10 +89,9 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 0 entry.expires = 0
cache = {service_name: [entry]} cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service( servers = yield resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -112,11 +111,12 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 999999999 entry.expires = 999999999
cache = {service_name: [entry]} cache = {service_name: [entry]}
resolver = SrvResolver(
servers = yield resolve_service( dns_client=dns_client_mock, cache=cache, get_time=clock.time,
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
) )
servers = yield resolver.resolve_service(service_name)
self.assertFalse(dns_client_mock.lookupService.called) self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 1)
@ -131,9 +131,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError): with self.assertRaises(error.DNSServerError):
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) yield resolver.resolve_service(service_name)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_name_error(self): def test_name_error(self):
@ -144,10 +145,9 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service( servers = yield resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertEquals(len(servers), 0) self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred dns_client_mock.lookupService.return_value = lookup_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError # returning a single "." should make the lookup fail with a ConenctError
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred dns_client_mock.lookupService.return_value = lookup_deferred
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service( resolve_d = resolver.resolve_service(service_name)
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
lookup_deferred.callback(( lookup_deferred.callback((