mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-17 08:24:20 -05:00
Convert the federation agent and related code to async/await. (#7874)
This commit is contained in:
parent
13d77464c9
commit
68cd935826
1
changelog.d/7874.misc
Normal file
1
changelog.d/7874.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert the federation agent and related code to async/await.
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from netaddr import AddrFormatError, IPAddress
|
from netaddr import AddrFormatError, IPAddress
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
@ -236,11 +237,10 @@ class MatrixHostnameEndpoint(object):
|
|||||||
|
|
||||||
return run_in_background(self._do_connect, protocol_factory)
|
return run_in_background(self._do_connect, protocol_factory)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _do_connect(self, protocol_factory):
|
||||||
def _do_connect(self, protocol_factory):
|
|
||||||
first_exception = None
|
first_exception = None
|
||||||
|
|
||||||
server_list = yield self._resolve_server()
|
server_list = await self._resolve_server()
|
||||||
|
|
||||||
for server in server_list:
|
for server in server_list:
|
||||||
host = server.host
|
host = server.host
|
||||||
@ -251,7 +251,7 @@ class MatrixHostnameEndpoint(object):
|
|||||||
endpoint = HostnameEndpoint(self._reactor, host, port)
|
endpoint = HostnameEndpoint(self._reactor, host, port)
|
||||||
if self._tls_options:
|
if self._tls_options:
|
||||||
endpoint = wrapClientTLS(self._tls_options, endpoint)
|
endpoint = wrapClientTLS(self._tls_options, endpoint)
|
||||||
result = yield make_deferred_yieldable(
|
result = await make_deferred_yieldable(
|
||||||
endpoint.connect(protocol_factory)
|
endpoint.connect(protocol_factory)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -271,13 +271,9 @@ class MatrixHostnameEndpoint(object):
|
|||||||
# to try and if that doesn't work then we'll have an exception.
|
# to try and if that doesn't work then we'll have an exception.
|
||||||
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
|
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _resolve_server(self) -> List[Server]:
|
||||||
def _resolve_server(self):
|
|
||||||
"""Resolves the server name to a list of hosts and ports to attempt to
|
"""Resolves the server name to a list of hosts and ports to attempt to
|
||||||
connect to.
|
connect to.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[list[Server]]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._parsed_uri.scheme != b"matrix":
|
if self._parsed_uri.scheme != b"matrix":
|
||||||
@ -298,7 +294,7 @@ class MatrixHostnameEndpoint(object):
|
|||||||
if port or _is_ip_literal(host):
|
if port or _is_ip_literal(host):
|
||||||
return [Server(host, port or 8448)]
|
return [Server(host, port or 8448)]
|
||||||
|
|
||||||
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
|
server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
|
||||||
|
|
||||||
if server_list:
|
if server_list:
|
||||||
return server_list
|
return server_list
|
||||||
|
@ -17,10 +17,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.names import client, dns
|
from twisted.names import client, dns
|
||||||
from twisted.names.error import DNSNameError, DomainError
|
from twisted.names.error import DNSNameError, DomainError
|
||||||
@ -113,15 +113,13 @@ class SrvResolver(object):
|
|||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._get_time = get_time
|
self._get_time = get_time
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def resolve_service(self, service_name: bytes) -> List[Server]:
|
||||||
def resolve_service(self, service_name):
|
|
||||||
"""Look up a SRV record
|
"""Look up a SRV record
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name (bytes): record to look up
|
service_name (bytes): record to look up
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[Server]]:
|
|
||||||
a list of the SRV records, or an empty list if none found
|
a list of the SRV records, or an empty list if none found
|
||||||
"""
|
"""
|
||||||
now = int(self._get_time())
|
now = int(self._get_time())
|
||||||
@ -136,7 +134,7 @@ class SrvResolver(object):
|
|||||||
return _sort_server_list(servers)
|
return _sort_server_list(servers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answers, _, _ = yield make_deferred_yieldable(
|
answers, _, _ = await make_deferred_yieldable(
|
||||||
self._dns_client.lookupService(service_name)
|
self._dns_client.lookupService(service_name)
|
||||||
)
|
)
|
||||||
except DNSNameError:
|
except DNSNameError:
|
||||||
|
@ -67,6 +67,14 @@ def get_connection_factory():
|
|||||||
return test_server_connection_factory
|
return test_server_connection_factory
|
||||||
|
|
||||||
|
|
||||||
|
# Once Async Mocks or lambdas are supported this can go away.
|
||||||
|
def generate_resolve_service(result):
|
||||||
|
async def resolve_service(_):
|
||||||
|
return result
|
||||||
|
|
||||||
|
return resolve_service
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationAgentTests(unittest.TestCase):
|
class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Test the behaviour when the certificate on the server doesn't match the hostname
|
Test the behaviour when the certificate on the server doesn't match the hostname
|
||||||
"""
|
"""
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv1"] = "1.2.3.4"
|
self.reactor.lookups["testserv1"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
|
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
|
||||||
@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
Test the behaviour when the server name has no port, no SRV, and no well-known
|
Test the behaviour when the server name has no port, no SRV, and no well-known
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
"""Test the behaviour when the .well-known delegates elsewhere
|
"""Test the behaviour when the .well-known delegates elsewhere
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
self.reactor.lookups["target-server"] = "1::f"
|
self.reactor.lookups["target-server"] = "1::f"
|
||||||
|
|
||||||
@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
"""Test the behaviour when the server name has no port and no SRV record, but
|
"""Test the behaviour when the server name has no port and no SRV record, but
|
||||||
the .well-known has a 300 redirect
|
the .well-known has a 300 redirect
|
||||||
"""
|
"""
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
self.reactor.lookups["target-server"] = "1::f"
|
self.reactor.lookups["target-server"] = "1::f"
|
||||||
|
|
||||||
@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
|
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
# the config left to the default, which will not trust it (since the
|
# the config left to the default, which will not trust it (since the
|
||||||
# presented cert is signed by a test CA)
|
# presented cert is signed by a test CA)
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
config = default_config("test", parse=True)
|
config = default_config("test", parse=True)
|
||||||
@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Test the behaviour when there is a single SRV record
|
Test the behaviour when there is a single SRV record
|
||||||
"""
|
"""
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
|
||||||
Server(host=b"srvtarget", port=8443)
|
[Server(host=b"srvtarget", port=8443)]
|
||||||
]
|
)
|
||||||
self.reactor.lookups["srvtarget"] = "1.2.3.4"
|
self.reactor.lookups["srvtarget"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.assertEqual(host, "1.2.3.4")
|
self.assertEqual(host, "1.2.3.4")
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
|
||||||
Server(host=b"srvtarget", port=8443)
|
[Server(host=b"srvtarget", port=8443)]
|
||||||
]
|
)
|
||||||
|
|
||||||
self._handle_well_known_connection(
|
self._handle_well_known_connection(
|
||||||
client_factory,
|
client_factory,
|
||||||
@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
def test_idna_servername(self):
|
def test_idna_servername(self):
|
||||||
"""test the behaviour when the server name has idna chars in"""
|
"""test the behaviour when the server name has idna chars in"""
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||||
|
|
||||||
# the resolver is always called with the IDNA hostname as a native string.
|
# the resolver is always called with the IDNA hostname as a native string.
|
||||||
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
|
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
|
||||||
@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
def test_idna_srv_target(self):
|
def test_idna_srv_target(self):
|
||||||
"""test the behaviour when the target of a SRV record has idna chars"""
|
"""test the behaviour when the target of a SRV record has idna chars"""
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
|
||||||
Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
|
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
|
||||||
]
|
)
|
||||||
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
|
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
|
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
|
||||||
@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
def test_srv_fallbacks(self):
|
def test_srv_fallbacks(self):
|
||||||
"""Test that other SRV results are tried if the first one fails.
|
"""Test that other SRV results are tried if the first one fails.
|
||||||
"""
|
"""
|
||||||
|
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
[
|
||||||
Server(host=b"target.com", port=8443),
|
Server(host=b"target.com", port=8443),
|
||||||
Server(host=b"target.com", port=8444),
|
Server(host=b"target.com", port=8444),
|
||||||
]
|
]
|
||||||
|
)
|
||||||
self.reactor.lookups["target.com"] = "1.2.3.4"
|
self.reactor.lookups["target.com"] = "1.2.3.4"
|
||||||
|
|
||||||
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
|
||||||
|
@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
|
|||||||
from twisted.names import dns, error
|
from twisted.names import dns, error
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import SrvResolver
|
from synapse.http.federation.srv_resolver import SrvResolver
|
||||||
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
|
from synapse.logging.context import LoggingContext, current_context
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
with LoggingContext("one") as ctx:
|
with LoggingContext("one") as ctx:
|
||||||
resolve_d = resolver.resolve_service(service_name)
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
|
result = yield defer.ensureDeferred(resolve_d)
|
||||||
self.assertNoResult(resolve_d)
|
|
||||||
|
|
||||||
# should have reset to the sentinel context
|
|
||||||
self.assertIs(current_context(), SENTINEL_CONTEXT)
|
|
||||||
|
|
||||||
result = yield resolve_d
|
|
||||||
|
|
||||||
# should have restored our context
|
# should have restored our context
|
||||||
self.assertIs(current_context(), ctx)
|
self.assertIs(current_context(), ctx)
|
||||||
@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolver.resolve_service(service_name)
|
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
dns_client=dns_client_mock, cache=cache, get_time=clock.time
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time
|
||||||
)
|
)
|
||||||
|
|
||||||
servers = yield resolver.resolve_service(service_name)
|
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
self.assertFalse(dns_client_mock.lookupService.called)
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
with self.assertRaises(error.DNSServerError):
|
with self.assertRaises(error.DNSServerError):
|
||||||
yield resolver.resolve_service(service_name)
|
yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_name_error(self):
|
def test_name_error(self):
|
||||||
@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
cache = {}
|
cache = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolver.resolve_service(service_name)
|
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
self.assertEquals(len(servers), 0)
|
self.assertEquals(len(servers), 0)
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
cache = {}
|
cache = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolver.resolve_service(service_name)
|
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
|
||||||
self.assertNoResult(resolve_d)
|
resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
# returning a single "." should make the lookup fail with a ConenctError
|
# returning a single "." should make the lookup fail with a ConenctError
|
||||||
lookup_deferred.callback(
|
lookup_deferred.callback(
|
||||||
@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
cache = {}
|
cache = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolver.resolve_service(service_name)
|
# Old versions of Twisted don't have an ensureDeferred in successResultOf.
|
||||||
self.assertNoResult(resolve_d)
|
resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
lookup_deferred.callback(
|
lookup_deferred.callback(
|
||||||
(
|
(
|
||||||
|
Loading…
Reference in New Issue
Block a user