Do an AAAA lookup on SRV record targets (#2462)

Support SRV records which point at AAAA records, as well as A records.

Fixes https://github.com/matrix-org/synapse/issues/2405
This commit is contained in:
Richard van der Hoff 2017-09-22 20:26:47 +01:00 committed by GitHub
parent f496399ac4
commit f65e31d22f
2 changed files with 118 additions and 24 deletions

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port expires" "_Server", "priority weight host port expires"
) )
@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
return self.default_server return self.default_server
else: else:
raise ConnectError( raise ConnectError(
"Not server available for %s" % self.service_name "No server available for %s" % self.service_name
) )
# look for all servers with the same priority
min_priority = self.servers[0].priority min_priority = self.servers[0].priority
weight_indexes = list( weight_indexes = list(
(index, server.weight + 1) (index, server.weight + 1)
@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes) total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight) target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes: for index, weight in weight_indexes:
target_weight -= weight target_weight -= weight
if target_weight <= 0: if target_weight <= 0:
server = self.servers[index] server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index] del self.servers[index]
self.used_servers.append(server) self.used_servers.append(server)
return server return server
@ -280,18 +296,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue continue
payload = answer.payload payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl
try: hosts = yield _get_hosts_for_srv_record(
answers, _, _ = yield dns_client.lookupAddress(host) dns_client, str(payload.target)
except DNSNameError: )
continue
for answer in answers: for (ip, ttl) in hosts:
if answer.type == dns.A and answer.payload: host_ttl = min(answer.ttl, ttl)
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
servers.append(_Server( servers.append(_Server(
host=ip, host=ip,
@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e raise e
defer.returnValue(servers) defer.returnValue(servers)
@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []
def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.
return res[0]
def eb(res):
res.trap(DNSNameError)
return []
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
for result in results:
for answer in result:
if not answer.payload:
continue
try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue
# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)

View File

@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
from tests.utils import MockClock from tests.utils import MockClock
@unittest.DEBUG
class DnsTestCase(unittest.TestCase): class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve(self): def test_resolve(self):
dns_client_mock = Mock() dns_client_mock = Mock()
service_name = "test_service.examle.com" service_name = "test_service.example.com"
host_name = "example.com" host_name = "example.com"
ip_address = "127.0.0.1" ip_address = "127.0.0.1"
ip6_address = "::1"
answer_srv = dns.RRHeader( answer_srv = dns.RRHeader(
type=dns.SRV, type=dns.SRV,
@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
) )
) )
dns_client_mock.lookupService.return_value = ([answer_srv], None, None) answer_aaaa = dns.RRHeader(
dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)
dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None),
)
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)
cache = {} cache = {}
@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name) dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name]) self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address) self.assertEquals(servers[0].host, ip_address)
self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self): def test_from_cache_expired_and_dns_fail(self):