mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-03 02:34:50 -04:00
Refactor and bugfix for resove_service (#4427)
This commit is contained in:
parent
23b0813599
commit
33a55289cb
6 changed files with 250 additions and 86 deletions
|
@ -12,30 +12,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import client, dns
|
||||
from twisted.names.error import DNSNameError, DomainError
|
||||
|
||||
from synapse.http.federation.srv_resolver import Server, resolve_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_CACHE = {}
|
||||
|
||||
# our record of an individual server which can be tried to reach a destination.
|
||||
#
|
||||
# "host" is the hostname acquired from the SRV record. Except when there's
|
||||
# no SRV record, in which case it is the original hostname.
|
||||
_Server = collections.namedtuple(
|
||||
"_Server", "priority weight host port expires"
|
||||
)
|
||||
|
||||
|
||||
def parse_server_name(server_name):
|
||||
"""Split a server name into host/port parts.
|
||||
|
@ -165,12 +153,9 @@ class SRVClientEndpoint(object):
|
|||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
||||
|
||||
if default_port is not None:
|
||||
self.default_server = _Server(
|
||||
self.default_server = Server(
|
||||
host=domain,
|
||||
port=default_port,
|
||||
priority=0,
|
||||
weight=0,
|
||||
expires=0,
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
@ -240,57 +225,3 @@ class SRVClientEndpoint(object):
|
|||
)
|
||||
connection = yield endpoint.connect(protocolFactory)
|
||||
defer.returnValue(connection)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
servers = []
|
||||
|
||||
try:
|
||||
try:
|
||||
answers, _, _ = yield dns_client.lookupService(service_name)
|
||||
except DNSNameError:
|
||||
defer.returnValue([])
|
||||
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name(b'.')):
|
||||
raise ConnectError("Service %s unavailable" % service_name)
|
||||
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
|
||||
payload = answer.payload
|
||||
|
||||
servers.append(_Server(
|
||||
host=str(payload.target),
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight),
|
||||
expires=int(clock.time()) + answer.ttl,
|
||||
))
|
||||
|
||||
servers.sort()
|
||||
cache[service_name] = list(servers)
|
||||
except DomainError as e:
|
||||
# We failed to resolve the name (other than a NameError)
|
||||
# Try something in the cache, else rereaise
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
logger.warn(
|
||||
"Failed to resolve %r, falling back to cache. %r",
|
||||
service_name, e
|
||||
)
|
||||
servers = list(cache_entry)
|
||||
else:
|
||||
raise e
|
||||
|
||||
defer.returnValue(servers)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue