diff --git a/changelog.d/14988.misc b/changelog.d/14988.misc new file mode 100644 index 000000000..93ceaeafc --- /dev/null +++ b/changelog.d/14988.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 93de1c97e..11e683b70 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,9 +32,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - |tests/http/federation/test_matrix_federation_agent.py - |tests/http/federation/test_srv_resolver.py - |tests/http/test_proxyagent.py |tests/module_api/test_api.py |tests/rest/media/v1/test_media_storage.py |tests/server.py @@ -92,6 +89,9 @@ disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.http.*] +disallow_untyped_defs = True + [mypy-tests.logging.*] disallow_untyped_defs = True diff --git a/synapse/http/client.py b/synapse/http/client.py index 4eb740c04..a05f29793 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -44,6 +44,7 @@ from twisted.internet.interfaces import ( IAddress, IDelayedCall, IHostResolution, + IReactorCore, IReactorPluggableNameResolver, IReactorTime, IResolutionReceiver, @@ -226,7 +227,9 @@ class _IPBlacklistingResolver: return recv -@implementer(ISynapseReactor) +# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer +# of IReactorCore seems to keep mypy-zope happier. +@implementer(IReactorCore, ISynapseReactor) class BlacklistingReactorWrapper: """ A Reactor wrapper which will prevent DNS resolution to blacklisted IP diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 18899bc6d..94ef737b9 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -38,7 +38,6 @@ from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials -from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor: IReactorCore, - proxy_reactor: Optional[ISynapseReactor] = None, + proxy_reactor: Optional[IReactorCore] = None, contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 093537ade..528cdee34 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -19,13 +19,15 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 -def get_test_https_policy(): +def get_test_https_policy() -> BrowserLikePolicyForHTTPS: """Get a test IPolicyForHTTPS which trusts the test CA cert Returns: @@ -39,7 +41,7 @@ def get_test_https_policy(): return BrowserLikePolicyForHTTPS(trustRoot=trust_root) -def get_test_ca_cert_file(): +def get_test_ca_cert_file() -> str: """Get the path to the test CA cert The keypair is generated with: @@ -51,7 +53,7 @@ def get_test_ca_cert_file(): return os.path.join(os.path.dirname(__file__), "ca.crt") -def get_test_key_file(): +def get_test_key_file() -> str: """get the path to the test key The key file is made with: @@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" - def __init__(self, sanlist): + def __init__(self, sanlist: List[bytes]): """ Args: - sanlist: list[bytes]: a list of subjectAltName values for the cert + sanlist: a list of subjectAltName values for the cert """ self._cert_file = create_test_cert_file(sanlist) - def serverConnectionForTLS(self, tlsProtocol): + def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection: ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_certificate_file(self._cert_file) ctx.use_privatekey_file(get_test_key_file()) return Connection(ctx, None) + + +# A dummy address, useful for tests that use FakeTransport and don't care about where +# packets are going to/coming from. +dummy_address = IPv4Address("TCP", "127.0.0.1", 80) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 992d8f94f..acfdcd3bc 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import Any, Awaitable, Callable, Generator, List, Optional, cast from unittest.mock import Mock, patch import treq @@ -24,14 +24,19 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions -from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.defer import Deferred +from twisted.internet.endpoints import _WrappingProtocol +from twisted.internet.interfaces import ( + IOpenSSLClientConnectionCreator, + IProtocolFactory, +) from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers -from twisted.web.iweb import IPolicyForHTTPS +from twisted.web.iweb import IPolicyForHTTPS, IResponse from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS @@ -42,11 +47,21 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.types import ISynapseReactor from synapse.util.caches.ttlcache import TTLCache from tests import unittest -from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_ca_cert_file, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.utils import default_config @@ -54,15 +69,17 @@ logger = logging.getLogger(__name__) # Once Async Mocks or lambdas are supported this can go away. -def generate_resolve_service(result): - async def resolve_service(_): +def generate_resolve_service( + result: List[Server], +) -> Callable[[Any], Awaitable[List[Server]]]: + async def resolve_service(_: Any) -> List[Server]: return result return resolve_service class MatrixFederationAgentTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.mock_resolver = Mock() @@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase): self.tls_factory = FederationPolicyForHTTPS(config) - self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) - self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) + self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache( + "test_cache", timer=self.reactor.seconds + ) self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), @@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase): self, client_factory: IProtocolFactory, ssl: bool = True, - expected_sni: bytes = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + expected_sni: Optional[bytes] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection Args: @@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase): if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) - + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an # HTTP11ClientProtocol) and wire the output of said protocol up to the server via @@ -125,7 +146,8 @@ class MatrixFederationAgentTests(unittest.TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert isinstance(client_protocol, _WrappingProtocol) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -136,6 +158,7 @@ class MatrixFederationAgentTests(unittest.TestCase): ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) # fish the test server back out of the server-side TLS protocol. http_protocol = server_protocol.wrappedProtocol # grab a hold of the TLS connection, in case it gets torn down @@ -144,6 +167,7 @@ class MatrixFederationAgentTests(unittest.TestCase): http_protocol = server_protocol tls_connection = None + assert isinstance(http_protocol, HTTPChannel) # give the reactor a pump to get the TLS juices flowing (if needed) self.reactor.advance(0) @@ -159,12 +183,14 @@ class MatrixFederationAgentTests(unittest.TestCase): return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri: bytes): + def _make_get_request( + self, uri: bytes + ) -> Generator["Deferred[object]", object, IResponse]: """ Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b"GET", uri) + fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -172,8 +198,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # should have reset logcontext to the sentinel _check_logcontext(SENTINEL_CONTEXT) + fetch_res: IResponse try: - fetch_res = yield fetch_d + fetch_res = yield fetch_d # type: ignore[misc, assignment] return fetch_res except Exception as e: logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) @@ -216,7 +243,7 @@ class MatrixFederationAgentTests(unittest.TestCase): request: Request, content: bytes, headers: Optional[dict] = None, - ): + ) -> None: """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -237,16 +264,16 @@ class MatrixFederationAgentTests(unittest.TestCase): because it is created too early during setUp """ return MatrixFederationAgent( - reactor=self.reactor, + reactor=cast(ISynapseReactor, self.reactor), tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) - def test_get(self): + def test_get(self) -> None: """happy-path test of a GET request with an explicit port""" self._do_get() @@ -254,11 +281,11 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "proxy.com", "no_proxy": "testserv"}, ) - def test_get_bypass_proxy(self): + def test_get_bypass_proxy(self) -> None: """test of a GET request with an explicit port and bypass proxy""" self._do_get() - def _do_get(self): + def _do_get(self) -> None: """test of a GET request with an explicit port""" self.agent = self._make_agent() @@ -318,7 +345,7 @@ class MatrixFederationAgentTests(unittest.TestCase): @patch.dict( os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_http_proxy(self): + def test_get_via_http_proxy(self) -> None: """test for federation request through a http proxy""" self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) @@ -326,7 +353,7 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_http_proxy_with_auth(self): + def test_get_via_http_proxy_with_auth(self) -> None: """test for federation request through a http proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" @@ -335,7 +362,7 @@ class MatrixFederationAgentTests(unittest.TestCase): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_get_via_https_proxy(self): + def test_get_via_https_proxy(self) -> None: """test for federation request through a https proxy""" self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) @@ -343,7 +370,7 @@ class MatrixFederationAgentTests(unittest.TestCase): os.environ, {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, ) - def test_get_via_https_proxy_with_auth(self): + def test_get_via_https_proxy_with_auth(self) -> None: """test for federation request through a https proxy with authentication""" self._do_get_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" @@ -353,7 +380,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https federation request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -418,10 +445,12 @@ class MatrixFederationAgentTests(unittest.TestCase): # now we make another test server to act as the upstream HTTP server. server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... and replace the protocol on the proxy's transport with the @@ -451,6 +480,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -491,7 +521,7 @@ class MatrixFederationAgentTests(unittest.TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) - def test_get_ip_address(self): + def test_get_ip_address(self) -> None: """ Test the behaviour when the server name contains an explicit IP (with no port) """ @@ -526,7 +556,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address(self): + def test_get_ipv6_address(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with no port) @@ -562,7 +592,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_ipv6_address_with_port(self): + def test_get_ipv6_address_with_port(self) -> None: """ Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) @@ -598,7 +628,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_hostname_bad_cert(self): + def test_get_hostname_bad_cert(self) -> None: """ Test the behaviour when the certificate on the server doesn't match the hostname """ @@ -651,7 +681,7 @@ class MatrixFederationAgentTests(unittest.TestCase): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_ip_address_bad_cert(self): + def test_get_ip_address_bad_cert(self) -> None: """ Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it @@ -684,7 +714,7 @@ class MatrixFederationAgentTests(unittest.TestCase): failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) - def test_get_no_srv_no_well_known(self): + def test_get_no_srv_no_well_known(self) -> None: """ Test the behaviour when the server name has no port, no SRV, and no well-known """ @@ -740,7 +770,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known(self): + def test_get_well_known(self) -> None: """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() @@ -802,7 +832,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_well_known_redirect(self): + def test_get_well_known_redirect(self) -> None: """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ @@ -892,7 +922,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) - def test_get_invalid_well_known(self): + def test_get_invalid_well_known(self) -> None: """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ @@ -945,7 +975,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_unsigned_cert(self): + def test_get_well_known_unsigned_cert(self) -> None: """Test the behaviour when the .well-known server presents a cert not signed by a CA """ @@ -969,7 +999,7 @@ class MatrixFederationAgentTests(unittest.TestCase): ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( - self.reactor, + cast(ISynapseReactor, self.reactor), Agent(self.reactor, contextFactory=tls_factory), b"test-agent", well_known_cache=self.well_known_cache, @@ -999,7 +1029,7 @@ class MatrixFederationAgentTests(unittest.TestCase): b"_matrix._tcp.testserv" ) - def test_get_hostname_srv(self): + def test_get_hostname_srv(self) -> None: """ Test the behaviour when there is a single SRV record """ @@ -1041,7 +1071,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_get_well_known_srv(self): + def test_get_well_known_srv(self) -> None: """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ @@ -1101,7 +1131,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_servername(self): + def test_idna_servername(self) -> None: """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() @@ -1163,7 +1193,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_idna_srv_target(self): + def test_idna_srv_target(self) -> None: """test the behaviour when the target of a SRV record has idna chars""" self.agent = self._make_agent() @@ -1206,7 +1236,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - def test_well_known_cache(self): + def test_well_known_cache(self) -> None: self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( @@ -1262,7 +1292,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"other-server") - def test_well_known_cache_with_temp_failure(self): + def test_well_known_cache_with_temp_failure(self) -> None: """Test that we refetch well-known before the cache expires, and that it ignores transient errors. """ @@ -1341,7 +1371,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) - def test_well_known_too_large(self): + def test_well_known_too_large(self) -> None: """A well-known query that returns a result which is too large should be rejected.""" self.reactor.lookups["testserv"] = "1.2.3.4" @@ -1367,7 +1397,7 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) - def test_srv_fallbacks(self): + def test_srv_fallbacks(self) -> None: """Test that other SRV results are tried if the first one fails.""" self.agent = self._make_agent() @@ -1427,7 +1457,7 @@ class MatrixFederationAgentTests(unittest.TestCase): class TestCachePeriodFromHeaders(unittest.TestCase): - def test_cache_control(self): + def test_cache_control(self) -> None: # uppercase self.assertEqual( _cache_period_from_headers( @@ -1464,7 +1494,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): 0, ) - def test_expires(self): + def test_expires(self) -> None: self.assertEqual( _cache_period_from_headers( Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), @@ -1491,14 +1521,14 @@ class TestCachePeriodFromHeaders(unittest.TestCase): self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) -def _check_logcontext(context): +def _check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate @@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") @@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 77ce8432a..7748f56ee 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Dict, Generator, List, Tuple, cast from unittest.mock import Mock from twisted.internet import defer @@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred from twisted.internet.error import ConnectError from twisted.names import dns, error -from synapse.http.federation.srv_resolver import SrvResolver +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.logging.context import LoggingContext, current_context from tests import unittest @@ -28,7 +28,7 @@ from tests.utils import MockClock class SrvResolverTestCase(unittest.TestCase): - def test_resolve(self): + def test_resolve(self) -> None: dns_client_mock = Mock() service_name = b"test_service.example.com" @@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase): type=dns.SRV, payload=dns.Record_SRV(target=host_name) ) - result_deferred = Deferred() + result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock.lookupService.return_value = result_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @defer.inlineCallbacks - def do_lookup(): + def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - result = yield defer.ensureDeferred(resolve_d) + result: List[Server] + result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment] # should have restored our context self.assertIs(current_context(), ctx) @@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks - def test_from_cache_expired_and_dns_fail(self): + def test_from_cache_expired_and_dns_fail( + self, + ) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) @@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache(self) -> Generator["Deferred[object]", object, None]: clock = MockClock() dns_client_mock = Mock(spec_set=["lookupService"]) @@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase): entry.priority = 0 entry.weight = 0 - cache = {service_name: [entry]} + cache = {service_name: [cast(Server, entry)]} resolver = SrvResolver( dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertFalse(dns_client_mock.lookupService.called) @@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase): self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks - def test_empty_cache(self): + def test_empty_cache(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks - def test_name_error(self): + def test_name_error(self) -> Generator["Deferred[object]", object, None]: dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) service_name = b"test_service.example.com" - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) + servers: List[Server] + servers = yield defer.ensureDeferred( + resolver.resolve_service(service_name) + ) # type: ignore[assignment] self.assertEqual(len(servers), 0) self.assertEqual(len(cache), 0) - def test_disabled_service(self): + def test_disabled_service(self) -> None: """ test the behaviour when there is a single record which is ".". """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in failureResultOf. @@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase): self.failureResultOf(resolve_d, ConnectError) - def test_non_srv_answer(self): + def test_non_srv_answer(self) -> None: """ test the behaviour when the dns server gives us a spurious non-SRV response """ service_name = b"test_service.example.com" - lookup_deferred = Deferred() + lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() dns_client_mock = Mock() dns_client_mock.lookupService.return_value = lookup_deferred - cache = {} + cache: Dict[bytes, List[Server]] = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) # Old versions of Twisted don't have an ensureDeferred in successResultOf. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 5071f8357..36472e57a 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: return method_name -def _hash_stack(stack: List[inspect.FrameInfo]): +def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]: """Turns a stack into a hashable value that can be put into a set.""" return tuple(_format_stack_frame(frame) for frame in stack) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 391196425..ec6aacf23 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -11,28 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any +from twisted.web.server import Request from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase class _AsyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_async"}) class _SyncTestCustomEndpoint: - def __init__(self, config, module_api): + def __init__(self, config: JsonDict, module_api: Any) -> None: pass - async def handle_request(self, request): + async def handle_request(self, request: Request) -> None: + assert isinstance(request, SynapseRequest) respond_with_json(request, 200, {"some_key": "some_value_sync"}) @@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase): and async handlers. """ - def test_async(self): + def test_async(self) -> None: handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) @@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) - def test_sync(self): + def test_sync(self) -> None: handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 7e2f2a01c..9cfe1ad0d 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -13,10 +13,12 @@ # limitations under the License. from io import BytesIO +from typing import Tuple, Union from unittest.mock import Mock from netaddr import IPSet +from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol @@ -28,6 +30,7 @@ from synapse.http.client import ( BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + _DiscardBodyWithMaxSizeProtocol, read_body_with_max_size, ) @@ -36,7 +39,9 @@ from tests.unittest import TestCase class ReadBodyWithMaxSizeTests(TestCase): - def _build_response(self, length=UNKNOWN_LENGTH): + def _build_response( + self, length: Union[int, str] = UNKNOWN_LENGTH + ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: """Start reading the body, returns the response, result and proto""" response = Mock(length=length) result = BytesIO() @@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase): return result, deferred, protocol - def _assert_error(self, deferred, protocol): + def _assert_error( + self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol + ) -> None: """Ensure that the expected error is received.""" - self.assertIsInstance(deferred.result, Failure) + assert isinstance(deferred.result, Failure) self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) - protocol.transport.abortConnection.assert_called_once() + assert protocol.transport is not None + # type-ignore: presumably abortConnection has been replaced with a Mock. + protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] - def _cleanup_error(self, deferred): + def _cleanup_error(self, deferred: "Deferred[int]") -> None: """Ensure that the error in the Deferred is handled gracefully.""" called = [False] - def errback(f): + def errback(f: Failure) -> None: called[0] = True deferred.addErrback(errback) self.assertTrue(called[0]) - def test_no_error(self): + def test_no_error(self) -> None: """A response that is NOT too large.""" result, deferred, protocol = self._build_response() @@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self.assertEqual(result.getvalue(), b"12345") self.assertEqual(deferred.result, 5) - def test_too_large(self): + def test_too_large(self) -> None: """A response which is too large raises an exception.""" result, deferred, protocol = self._build_response() @@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_multiple_packets(self): + def test_multiple_packets(self) -> None: """Data should be accumulated through mutliple packets.""" result, deferred, protocol = self._build_response() @@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self.assertEqual(result.getvalue(), b"1234") self.assertEqual(deferred.result, 4) - def test_additional_data(self): + def test_additional_data(self) -> None: """A connection can receive data after being closed.""" result, deferred, protocol = self._build_response() @@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase): self._assert_error(deferred, protocol) self._cleanup_error(deferred) - def test_content_length(self): + def test_content_length(self) -> None: """The body shouldn't be read (at all) if the Content-Length header is too large.""" result, deferred, protocol = self._build_response(length=10) @@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase): class BlacklistingAgentTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" @@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase): self.ip_whitelist = IPSet([self.allowed_ip.decode()]) self.ip_blacklist = IPSet(["5.0.0.0/8"]) - def test_reactor(self): + def test_reactor(self) -> None: """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" agent = Agent( BlacklistingReactorWrapper( @@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase): response = self.successResultOf(d) self.assertEqual(response.code, 200) - def test_agent(self): + def test_agent(self) -> None: """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" agent = BlacklistingAgentWrapper( Agent(self.reactor), diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index a801f002a..8c18e5688 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -17,7 +17,7 @@ from tests import unittest class ServerNameTestCase(unittest.TestCase): - def test_parse_server_name(self): + def test_parse_server_name(self) -> None: test_data = { "localhost": ("localhost", None), "my-example.com:1234": ("my-example.com", 1234), @@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase): for i, o in test_data.items(): self.assertEqual(parse_server_name(i), o) - def test_validate_bad_server_names(self): + def test_validate_bad_server_names(self) -> None: test_data = [ "", # empty "localhost:http", # non-numeric port diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index be9eaf34e..fdd22a8e9 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Generator from unittest.mock import Mock from netaddr import IPSet from parameterized import parameterized from twisted.internet import defer -from twisted.internet.defer import TimeoutError +from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from twisted.web.client import ResponseNeverReceived from twisted.web.http import HTTPChannel @@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + LoggingContextOrSentinel, + current_context, +) +from synapse.server import HomeServer +from synapse.util import Clock from tests.server import FakeTransport from tests.unittest import HomeserverTestCase -def check_logcontext(context): +def check_logcontext(context: LoggingContextOrSentinel) -> None: current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(reactor=reactor, clock=clock) return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.cl = MatrixFederationHttpClient(self.hs, None) self.reactor.lookups["testserv"] = "1.2.3.4" - def test_client_get(self): + def test_client_get(self) -> None: """ happy-path test of a GET request """ @defer.inlineCallbacks - def do_request(): + def do_request() -> Generator["Deferred[object]", object, object]: with LoggingContext("one") as context: fetch_d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar") @@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase): # check the response is as expected self.assertEqual(res, {"a": 1}) - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) ) @@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIs(f.value.inner_exception, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. @@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase): f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist @@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase): f = self.failureResultOf(d, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError) - def test_client_gets_headers(self): + def test_client_gets_headers(self) -> None: """ Once the client gets the headers, _request returns successfully. """ @@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase): self.assertEqual(r.code, 200) @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) - def test_timeout_reading_body(self, method_name: str): + def test_timeout_reading_body(self, method_name: str) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a RequestSendFailed with can_retry. @@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase): self.assertTrue(f.value.can_retry) self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) - def test_client_requires_trailing_slashes(self): + def test_client_requires_trailing_slashes(self) -> None: """ If a connection is made to a client but the client rejects it due to requiring a trailing slash. We need to retry the request with a @@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase): r = self.successResultOf(d) self.assertEqual(r, {}) - def test_client_does_not_retry_on_400_plus(self): + def test_client_does_not_retry_on_400_plus(self) -> None: """ Another test for trailing slashes but now test that we don't retry on trailing slashes on a non-400/M_UNRECOGNIZED response. @@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase): # We should get a 404 failure response self.failureResultOf(d) - def test_client_sends_body(self): + def test_client_sends_body(self) -> None: defer.ensureDeferred( self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} @@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase): content = request.content.read() self.assertEqual(content, b'{"a":"b"}') - def test_closes_connection(self): + def test_closes_connection(self) -> None: """Check that the client closes unused HTTP connections""" d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) @@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase): self.assertTrue(conn.disconnecting) @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) - def test_json_error(self, return_value): + def test_json_error(self, return_value: bytes) -> None: """ Test what happens if invalid JSON is returned from the remote endpoint. """ @@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase): f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) - def test_too_big(self): + def test_too_big(self) -> None: """ Test what happens if a huge response is returned from the remote endpoint. """ diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 2db77c6a7..a81794073 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -14,7 +14,7 @@ import base64 import logging import os -from typing import Iterable, Optional +from typing import List, Optional from unittest.mock import patch import treq @@ -22,7 +22,11 @@ from netaddr import IPSet from parameterized import parameterized from twisted.internet import interfaces # noqa: F401 -from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint +from twisted.internet.endpoints import ( + HostnameEndpoint, + _WrapperEndpoint, + _WrappingProtocol, +) from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol @@ -32,7 +36,11 @@ from synapse.http.client import BlacklistingReactorWrapper from synapse.http.connectproxyclient import ProxyCredentials from synapse.http.proxyagent import ProxyAgent, parse_proxy -from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.http import ( + TestServerTLSConnectionFactory, + dummy_address, + get_test_https_policy, +) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase @@ -183,7 +191,7 @@ class ProxyParserTests(TestCase): expected_hostname: bytes, expected_port: int, expected_credentials: Optional[bytes], - ): + ) -> None: """ Tests that a given proxy URL will be broken into the components. Args: @@ -209,7 +217,7 @@ class ProxyParserTests(TestCase): class MatrixFederationAgentTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() def _make_connection( @@ -218,7 +226,7 @@ class MatrixFederationAgentTests(TestCase): server_factory: IProtocolFactory, ssl: bool = False, expected_sni: Optional[bytes] = None, - tls_sanlist: Optional[Iterable[bytes]] = None, + tls_sanlist: Optional[List[bytes]] = None, ) -> IProtocol: """Builds a test server, and completes the outgoing client connection @@ -244,7 +252,8 @@ class MatrixFederationAgentTests(TestCase): if ssl: server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) - server_protocol = server_factory.buildProtocol(None) + server_protocol = server_factory.buildProtocol(dummy_address) + assert server_protocol is not None # now, tell the client protocol factory to build the client protocol, # and wire the output of said protocol up to the server via @@ -252,7 +261,8 @@ class MatrixFederationAgentTests(TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(None) + client_protocol = client_factory.buildProtocol(dummy_address) + assert client_protocol is not None client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -263,6 +273,7 @@ class MatrixFederationAgentTests(TestCase): ) if ssl: + assert isinstance(server_protocol, TLSMemoryBIOProtocol) http_protocol = server_protocol.wrappedProtocol tls_connection = server_protocol._tlsConnection else: @@ -288,7 +299,7 @@ class MatrixFederationAgentTests(TestCase): scheme: bytes, hostname: bytes, path: bytes, - ): + ) -> None: """Runs a test case for a direct connection not going through a proxy. Args: @@ -319,6 +330,7 @@ class MatrixFederationAgentTests(TestCase): ssl=is_https, expected_sni=hostname if is_https else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -339,34 +351,34 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - def test_http_request(self): + def test_http_request(self) -> None: agent = ProxyAgent(self.reactor) self._test_request_direct_connection(agent, b"http", b"test.com", b"") - def test_https_request(self): + def test_https_request(self) -> None: agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") - def test_http_request_use_proxy_empty_environment(self): + def test_http_request_use_proxy_empty_environment(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"}) - def test_http_request_via_uppercase_no_proxy(self): + def test_http_request_via_uppercase_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"} ) - def test_http_request_via_no_proxy(self): + def test_http_request_via_no_proxy(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"} ) - def test_https_request_via_no_proxy(self): + def test_https_request_via_no_proxy(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -375,12 +387,12 @@ class MatrixFederationAgentTests(TestCase): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"}) - def test_http_request_via_no_proxy_star(self): + def test_http_request_via_no_proxy_star(self) -> None: agent = ProxyAgent(self.reactor, use_proxy=True) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"}) - def test_https_request_via_no_proxy_star(self): + def test_https_request_via_no_proxy_star(self) -> None: agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -389,7 +401,7 @@ class MatrixFederationAgentTests(TestCase): self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) - def test_http_request_via_proxy(self): + def test_http_request_via_proxy(self) -> None: """ Tests that requests can be made through a proxy. """ @@ -401,7 +413,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, ) - def test_http_request_via_proxy_with_auth(self): + def test_http_request_via_proxy_with_auth(self) -> None: """ Tests that authenticated requests can be made through a proxy. """ @@ -412,7 +424,7 @@ class MatrixFederationAgentTests(TestCase): @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) - def test_http_request_via_https_proxy(self): + def test_http_request_via_https_proxy(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None ) @@ -424,13 +436,13 @@ class MatrixFederationAgentTests(TestCase): "no_proxy": "unused.com", }, ) - def test_http_request_via_https_proxy_with_auth(self): + def test_http_request_via_https_proxy_with_auth(self) -> None: self._do_http_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): + def test_https_request_via_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=None @@ -440,7 +452,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_proxy_with_auth(self): + def test_https_request_via_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" @@ -449,7 +461,7 @@ class MatrixFederationAgentTests(TestCase): @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) - def test_https_request_via_https_proxy(self): + def test_https_request_via_https_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=None @@ -459,7 +471,7 @@ class MatrixFederationAgentTests(TestCase): os.environ, {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, ) - def test_https_request_via_https_proxy_with_auth(self): + def test_https_request_via_https_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" @@ -469,7 +481,7 @@ class MatrixFederationAgentTests(TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: @@ -501,6 +513,7 @@ class MatrixFederationAgentTests(TestCase): tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -542,7 +555,7 @@ class MatrixFederationAgentTests(TestCase): self, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, - ): + ) -> None: """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: @@ -606,10 +619,12 @@ class MatrixFederationAgentTests(TestCase): # now we make another test server to act as the upstream HTTP server. server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() - ).buildProtocol(None) + ).buildProtocol(dummy_address) + assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport + assert proxy_server_transport is not None server_ssl_protocol.makeConnection(proxy_server_transport) # ... and replace the protocol on the proxy's transport with the @@ -644,6 +659,7 @@ class MatrixFederationAgentTests(TestCase): # now there should be a pending request http_server = server_ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -667,7 +683,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_http_request_via_proxy_with_blacklist(self): + def test_http_request_via_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -691,6 +707,7 @@ class MatrixFederationAgentTests(TestCase): http_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(http_server, HTTPChannel) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -712,7 +729,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"}) - def test_https_request_via_uppercase_proxy_with_blacklist(self): + def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None: # The blacklist includes the configured proxy IP. agent = ProxyAgent( BlacklistingReactorWrapper( @@ -737,11 +754,15 @@ class MatrixFederationAgentTests(TestCase): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory() ) + assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo s2c_transport = proxy_server.transport + assert isinstance(s2c_transport, FakeTransport) client_protocol = s2c_transport.other + assert isinstance(client_protocol, _WrappingProtocol) c2s_transport = client_protocol.transport + assert isinstance(c2s_transport, FakeTransport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -762,8 +783,10 @@ class MatrixFederationAgentTests(TestCase): # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) - ssl_protocol = ssl_factory.buildProtocol(None) + ssl_protocol = ssl_factory.buildProtocol(dummy_address) + assert isinstance(ssl_protocol, TLSMemoryBIOProtocol) http_server = ssl_protocol.wrappedProtocol + assert isinstance(http_server, HTTPChannel) ssl_protocol.makeConnection( FakeTransport(client_protocol, self.reactor, ssl_protocol) @@ -797,28 +820,28 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) - def test_proxy_with_no_scheme(self): + def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) - def test_proxy_with_unsupported_scheme(self): + def test_proxy_with_unsupported_scheme(self) -> None: with self.assertRaises(ValueError): ProxyAgent(self.reactor, use_proxy=True) @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) - def test_proxy_with_http_scheme(self): + def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) - def test_proxy_with_https_scheme(self): + def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) + assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) self.assertEqual( https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" ) @@ -828,7 +851,7 @@ class MatrixFederationAgentTests(TestCase): def _wrap_server_factory_for_tls( - factory: IProtocolFactory, sanlist: Iterable[bytes] = None + factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None ) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory @@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory: return server_factory -def _log_request(request: str): +def _log_request(request: str) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info(f"Completed request {request}") diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 46166292f..c8d215b6d 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -14,7 +14,7 @@ import json from http import HTTPStatus from io import BytesIO -from typing import Tuple +from typing import Tuple, Union from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -33,7 +33,7 @@ from tests import unittest from tests.http.server._base import test_disconnect -def make_request(content): +def make_request(content: Union[bytes, JsonDict]) -> Mock: """Make an object that acts enough like a request.""" request = Mock(spec=["method", "uri", "content"]) @@ -47,7 +47,7 @@ def make_request(content): class TestServletUtils(unittest.TestCase): - def test_parse_json_value(self): + def test_parse_json_value(self) -> None: """Basic tests for parse_json_value_from_request.""" # Test round-tripping. obj = {"foo": 1} @@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase): with self.assertRaises(SynapseError): parse_json_value_from_request(make_request(b'{"foo": Infinity}')) - def test_parse_json_object(self): + def test_parse_json_object(self) -> None: """Basic tests for parse_json_object_from_request.""" # Test empty. result = parse_json_object_from_request( diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py index c85a3665c..010601da4 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py @@ -17,22 +17,24 @@ from netaddr import IPSet from twisted.internet import defer from twisted.internet.error import DNSLookupError +from twisted.test.proto_helpers import MemoryReactor from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SimpleHttpClientTests(HomeserverTestCase): - def prepare(self, reactor, clock, hs: "HomeServer"): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None: # Add a DNS entry for a test server self.reactor.lookups["testserv"] = "1.2.3.4" self.cl = hs.get_simple_http_client() - def test_dns_error(self): + def test_dns_error(self) -> None: """ If the DNS lookup returns an error, it will bubble up. """ @@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase): f = self.failureResultOf(d) self.assertIsInstance(f.value, DNSLookupError) - def test_client_connection_refused(self): + def test_client_connection_refused(self) -> None: d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) self.pump() @@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIs(f.value, e) - def test_client_never_connect(self): + def test_client_never_connect(self) -> None: """ If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. @@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_connect_no_response(self): + def test_client_connect_no_response(self) -> None: """ If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. @@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestTimedOutError) - def test_client_ip_range_blacklist(self): + def test_client_ip_range_blacklist(self) -> None: """Ensure that Synapse does not try to connect to blacklisted IPs""" # Add some DNS entries we'll blacklist diff --git a/tests/http/test_site.py b/tests/http/test_site.py index b2dbf76d3..9a78fede9 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py @@ -13,18 +13,20 @@ # limitations under the License. from twisted.internet.address import IPv6Address -from twisted.test.proto_helpers import StringTransport +from twisted.test.proto_helpers import MemoryReactor, StringTransport from synapse.app.homeserver import SynapseHomeServer +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class SynapseRequestTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) - def test_large_request(self): + def test_large_request(self) -> None: """overlarge HTTP requests should be rejected""" self.hs.start_listening() diff --git a/tests/server.py b/tests/server.py index b1730fcc8..237bcad8b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -70,7 +70,7 @@ from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine -from synapse.types import JsonDict +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests.utils import ( @@ -401,7 +401,9 @@ def make_request( return channel -@implementer(IReactorPluggableNameResolver) +# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly +# marking this as an implementer of the latter seems to keep mypy-zope happier. +@implementer(IReactorPluggableNameResolver, ISynapseReactor) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread.