put resolve_service in an object

this makes it easier to stub things out for tests.
This commit is contained in:
Richard van der Hoff 2019-01-22 17:42:26 +00:00
parent 53a327b4d5
commit 7021784d46
3 changed files with 96 additions and 75 deletions

View file

@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import resolve_service
from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext
from tests import unittest
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = result_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def do_lookup():
with LoggingContext("one") as ctx:
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
@ -89,10 +89,9 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
servers = yield resolver.resolve_service(service_name)
dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -112,11 +111,12 @@ class SrvResolverTestCase(unittest.TestCase):
entry.expires = 999999999
cache = {service_name: [entry]}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
resolver = SrvResolver(
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
)
servers = yield resolver.resolve_service(service_name)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
@ -131,9 +131,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
yield resolver.resolve_service(service_name)
@defer.inlineCallbacks
def test_name_error(self):
@ -144,10 +145,9 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
servers = yield resolver.resolve_service(service_name)
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
lookup_deferred.callback((