Refactor and bugfix for resove_service (#4427)

This commit is contained in:
Richard van der Hoff 2019-01-22 10:59:27 +00:00 committed by GitHub
parent 23b0813599
commit 33a55289cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 250 additions and 86 deletions

View file

@ -12,30 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import random
import re
import time
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
from synapse.http.federation.srv_resolver import Server, resolve_service
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
@ -165,12 +153,9 @@ class SRVClientEndpoint(object):
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
self.default_server = _Server(
self.default_server = Server(
host=domain,
port=default_port,
priority=0,
weight=0,
expires=0,
)
else:
self.default_server = None
@ -240,57 +225,3 @@ class SRVClientEndpoint(object):
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(_Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + answer.ttl,
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)