mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-20 08:31:33 -05:00
Read from DNS cache if within TTL
This commit is contained in:
parent
a68c1b15aa
commit
f699b8f997
@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
|
|||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -31,7 +32,7 @@ SERVER_CACHE = {}
|
|||||||
|
|
||||||
|
|
||||||
_Server = collections.namedtuple(
|
_Server = collections.namedtuple(
|
||||||
"_Server", "priority weight host port"
|
"_Server", "priority weight host port expires"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
|
|||||||
host=domain,
|
host=domain,
|
||||||
port=default_port,
|
port=default_port,
|
||||||
priority=0,
|
priority=0,
|
||||||
weight=0
|
weight=0,
|
||||||
|
expires=0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.default_server = None
|
self.default_server = None
|
||||||
@ -154,6 +156,12 @@ class SRVClientEndpoint(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
||||||
|
cache_entry = cache.get(service_name, None)
|
||||||
|
if cache_entry:
|
||||||
|
if all(s.expires > int(time.time()) for s in cache_entry):
|
||||||
|
servers = list(cache_entry)
|
||||||
|
defer.returnValue(servers)
|
||||||
|
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -173,26 +181,25 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
payload = answer.payload
|
payload = answer.payload
|
||||||
|
|
||||||
host = str(payload.target)
|
host = str(payload.target)
|
||||||
|
srv_ttl = answer.ttl
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answers, _, _ = yield dns_client.lookupAddress(host)
|
answers, _, _ = yield dns_client.lookupAddress(host)
|
||||||
except DNSNameError:
|
except DNSNameError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ips = [
|
for answer in answers:
|
||||||
answer.payload.dottedQuad()
|
if answer.type == dns.A and answer.payload:
|
||||||
for answer in answers
|
ip = answer.payload.dottedQuad()
|
||||||
if answer.type == dns.A and answer.payload
|
host_ttl = min(srv_ttl, answer.ttl)
|
||||||
]
|
|
||||||
|
|
||||||
for ip in ips:
|
|
||||||
servers.append(_Server(
|
servers.append(_Server(
|
||||||
host=ip,
|
host=ip,
|
||||||
port=int(payload.port),
|
port=int(payload.port),
|
||||||
priority=int(payload.priority),
|
priority=int(payload.priority),
|
||||||
weight=int(payload.weight)
|
weight=int(payload.weight),
|
||||||
|
expires=int(time.time()) + host_ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
servers.sort()
|
servers.sort()
|
||||||
|
@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
service_name = "test_service.examle.com"
|
service_name = "test_service.examle.com"
|
||||||
|
|
||||||
|
entry = Mock(spec_set=["expires"])
|
||||||
|
entry.expires = 999999999
|
||||||
|
|
||||||
cache = {
|
cache = {
|
||||||
service_name: [object()]
|
service_name: [entry]
|
||||||
}
|
}
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolve_service(
|
||||||
|
Loading…
Reference in New Issue
Block a user