put resolve_service in an object

this makes it easier to stub things out for tests.
This commit is contained in:
Richard van der Hoff 2019-01-22 17:42:26 +00:00
parent 53a327b4d5
commit 7021784d46
3 changed files with 96 additions and 75 deletions

View file

@ -84,73 +84,86 @@ def pick_server_from_list(server_list):
)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
"""Look up a SRV record, with caching
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
service_name (bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
clock (object): clock implementation. must provide a time() method.
Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
get_time (callable): clock implementation. Should return seconds since the epoch
"""
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
@defer.inlineCallbacks
def resolve_service(self, service_name):
"""Look up a SRV record
try:
answers, _, _ = yield make_deferred_yieldable(
dns_client.lookupService(service_name),
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
Args:
service_name (bytes): record to look up
Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
)
defer.returnValue(list(cache_entry))
else:
raise e
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=int(clock.time()) + answer.ttl,
))
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
))
cache[service_name] = list(servers)
defer.returnValue(servers)
self._cache[service_name] = list(servers)
defer.returnValue(servers)