Cache dns lookups, and use the cache if we fail to lookup servers later

This commit is contained in:
Erik Johnston 2016-01-20 11:34:09 +00:00
parent af30140621
commit 191070123d
2 changed files with 186 additions and 30 deletions

View file

@ -17,7 +17,7 @@ from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError
from twisted.names.error import DNSNameError, DomainError
import collections
import logging
@ -27,6 +27,14 @@ import random
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
_Server = collections.namedtuple(
"_Server", "priority weight host port"
)
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@ -73,10 +81,6 @@ class SRVClientEndpoint(object):
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
_Server = collections.namedtuple(
"_Server", "priority weight host port"
)
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=TCP4ClientEndpoint,
endpoint_kw_args={}):
@ -101,32 +105,8 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
def fetch_servers(self):
try:
answers, auth, add = yield client.lookupService(self.service_name)
except DNSNameError:
answers = []
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", self.service_name)
self.servers = []
self.used_servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
self.servers.append(self._Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
self.servers.sort()
self.servers = yield resolve_service(self.service_name)
def pick_server(self):
if not self.servers:
@ -170,3 +150,64 @@ class SRVClientEndpoint(object):
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
host = str(payload.target)
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]
for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)