Convert the federation agent and related code to async/await. (#7874)

This commit is contained in:
Patrick Cloke 2020-07-23 07:05:57 -04:00 committed by GitHub
parent 13d77464c9
commit 68cd935826
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 53 deletions

1
changelog.d/7874.misc Normal file
View File

@ -0,0 +1 @@
Convert the federation agent and related code to async/await.

View File

@ -15,6 +15,7 @@
import logging import logging
import urllib import urllib
from typing import List
from netaddr import AddrFormatError, IPAddress from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer from zope.interface import implementer
@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object):
return run_in_background(self._do_connect, protocol_factory) return run_in_background(self._do_connect, protocol_factory)
@defer.inlineCallbacks async def _do_connect(self, protocol_factory):
def _do_connect(self, protocol_factory):
first_exception = None first_exception = None
server_list = yield self._resolve_server() server_list = await self._resolve_server()
for server in server_list: for server in server_list:
host = server.host host = server.host
@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object):
endpoint = HostnameEndpoint(self._reactor, host, port) endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options: if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint) endpoint = wrapClientTLS(self._tls_options, endpoint)
result = yield make_deferred_yieldable( result = await make_deferred_yieldable(
endpoint.connect(protocol_factory) endpoint.connect(protocol_factory)
) )
@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object):
# to try and if that doesn't work then we'll have an exception. # to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
@defer.inlineCallbacks async def _resolve_server(self) -> List[Server]:
def _resolve_server(self):
"""Resolves the server name to a list of hosts and ports to attempt to """Resolves the server name to a list of hosts and ports to attempt to
connect to. connect to.
Returns:
Deferred[list[Server]]
""" """
if self._parsed_uri.scheme != b"matrix": if self._parsed_uri.scheme != b"matrix":
@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object):
if port or _is_ip_literal(host): if port or _is_ip_literal(host):
return [Server(host, port or 8448)] return [Server(host, port or 8448)]
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list: if server_list:
return server_list return server_list

View File

@ -17,10 +17,10 @@
import logging import logging
import random import random
import time import time
from typing import List
import attr import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import client, dns from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError from twisted.names.error import DNSNameError, DomainError
@ -113,16 +113,14 @@ class SrvResolver(object):
self._cache = cache self._cache = cache
self._get_time = get_time self._get_time = get_time
@defer.inlineCallbacks async def resolve_service(self, service_name: bytes) -> List[Server]:
def resolve_service(self, service_name):
"""Look up a SRV record """Look up a SRV record
Args: Args:
service_name (bytes): record to look up service_name (bytes): record to look up
Returns: Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
a list of the SRV records, or an empty list if none found
""" """
now = int(self._get_time()) now = int(self._get_time())
@ -136,7 +134,7 @@ class SrvResolver(object):
return _sort_server_list(servers) return _sort_server_list(servers)
try: try:
answers, _, _ = yield make_deferred_yieldable( answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name) self._dns_client.lookupService(service_name)
) )
except DNSNameError: except DNSNameError:

View File

@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory return test_server_connection_factory
# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(result):
async def resolve_service(_):
return result
return resolve_service
class MatrixFederationAgentTests(unittest.TestCase): class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.reactor = ThreadedMemoryReactorClock() self.reactor = ThreadedMemoryReactorClock()
@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when the certificate on the server doesn't match the hostname Test the behaviour when the certificate on the server doesn't match the hostname
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4" self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar") test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known Test the behaviour when the server name has no port, no SRV, and no well-known
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar") test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere """Test the behaviour when the .well-known delegates elsewhere
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f" self.reactor.lookups["target-server"] = "1::f"
@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but """Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect the .well-known has a 300 redirect
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f" self.reactor.lookups["target-server"] = "1::f"
@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV) Test the behaviour when the server name has an *invalid* well-known (and no SRV)
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar") test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the # the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA) # presented cert is signed by a test CA)
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True) config = default_config("test", parse=True)
@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
Test the behaviour when there is a single SRV record Test the behaviour when there is a single SRV record
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
Server(host=b"srvtarget", port=8443) [Server(host=b"srvtarget", port=8443)]
] )
self.reactor.lookups["srvtarget"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar") test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4") self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443) self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = lambda _: [ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
Server(host=b"srvtarget", port=8443) [Server(host=b"srvtarget", port=8443)]
] )
self._handle_well_known_connection( self._handle_well_known_connection(
client_factory, client_factory,
@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self): def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in""" """test the behaviour when the server name has idna chars in"""
self.mock_resolver.resolve_service.side_effect = lambda _: [] self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string. # the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self): def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars""" """test the behaviour when the target of a SRV record has idna chars"""
self.mock_resolver.resolve_service.side_effect = lambda _: [ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
Server(host=b"xn--trget-3qa.com", port=8443) # târget.com [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
] )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self): def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails. """Test that other SRV results are tried if the first one fails.
""" """
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
self.mock_resolver.resolve_service.side_effect = lambda _: [ [
Server(host=b"target.com", port=8443), Server(host=b"target.com", port=8443),
Server(host=b"target.com", port=8444), Server(host=b"target.com", port=8444),
] ]
)
self.reactor.lookups["target.com"] = "1.2.3.4" self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar") test_d = self._make_get_request(b"matrix://testserv/foo/bar")

View File

@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver from synapse.http.federation.srv_resolver import SrvResolver
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.logging.context import LoggingContext, current_context
from tests import unittest from tests import unittest
from tests.utils import MockClock from tests.utils import MockClock
@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx: with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name) resolve_d = resolver.resolve_service(service_name)
result = yield defer.ensureDeferred(resolve_d)
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d
# should have restored our context # should have restored our context
self.assertIs(current_context(), ctx) self.assertIs(current_context(), ctx)
@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]} cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolver.resolve_service(service_name) servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time dns_client=dns_client_mock, cache=cache, get_time=clock.time
) )
servers = yield resolver.resolve_service(service_name) servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called) self.assertFalse(dns_client_mock.lookupService.called)
@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError): with self.assertRaises(error.DNSServerError):
yield resolver.resolve_service(service_name) yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_name_error(self): def test_name_error(self):
@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolver.resolve_service(service_name) servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0) self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolver.resolve_service(service_name) # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
self.assertNoResult(resolve_d) resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConenctError # returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback( lookup_deferred.callback(
@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {} cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolver.resolve_service(service_name) # Old versions of Twisted don't have an ensureDeferred in successResultOf.
self.assertNoResult(resolve_d) resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback( lookup_deferred.callback(
( (