mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Properly typecheck types.http (#14988)
* Tweak http types in Synapse AFACIS these are correct, and they make mypy happier on tests.http. * Type hints for test_proxyagent * type hints for test_srv_resolver * test_matrix_federation_agent * tests.http.server._base * tests.http.__init__ * tests.http.test_additional_resource * tests.http.test_client * tests.http.test_endpoint * tests.http.test_matrixfederationclient * tests.http.test_servlet * tests.http.test_simple_client * tests.http.test_site * One fixup in tests.server * Untyped defs * Changelog * Fixup syntax for Python 3.7 * Fix olddeps syntax * Use a twisted IPv4 addr for dummy_address * Fix typo, thanks Sean Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Remove redundant `Optional` --------- Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
parent
5fdc12f482
commit
d0fed7a37b
1
changelog.d/14988.misc
Normal file
1
changelog.d/14988.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve type hints.
|
6
mypy.ini
6
mypy.ini
@ -32,9 +32,6 @@ exclude = (?x)
|
|||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/http/federation/test_matrix_federation_agent.py
|
|
||||||
|tests/http/federation/test_srv_resolver.py
|
|
||||||
|tests/http/test_proxyagent.py
|
|
||||||
|tests/module_api/test_api.py
|
|tests/module_api/test_api.py
|
||||||
|tests/rest/media/v1/test_media_storage.py
|
|tests/rest/media/v1/test_media_storage.py
|
||||||
|tests/server.py
|
|tests/server.py
|
||||||
@ -92,6 +89,9 @@ disallow_untyped_defs = True
|
|||||||
[mypy-tests.handlers.*]
|
[mypy-tests.handlers.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-tests.http.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.logging.*]
|
[mypy-tests.logging.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ from twisted.internet.interfaces import (
|
|||||||
IAddress,
|
IAddress,
|
||||||
IDelayedCall,
|
IDelayedCall,
|
||||||
IHostResolution,
|
IHostResolution,
|
||||||
|
IReactorCore,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IReactorTime,
|
IReactorTime,
|
||||||
IResolutionReceiver,
|
IResolutionReceiver,
|
||||||
@ -226,7 +227,9 @@ class _IPBlacklistingResolver:
|
|||||||
return recv
|
return recv
|
||||||
|
|
||||||
|
|
||||||
@implementer(ISynapseReactor)
|
# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer
|
||||||
|
# of IReactorCore seems to keep mypy-zope happier.
|
||||||
|
@implementer(IReactorCore, ISynapseReactor)
|
||||||
class BlacklistingReactorWrapper:
|
class BlacklistingReactorWrapper:
|
||||||
"""
|
"""
|
||||||
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
|
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
|
||||||
|
@ -38,7 +38,6 @@ from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
|
|||||||
|
|
||||||
from synapse.http import redact_uri
|
from synapse.http import redact_uri
|
||||||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
|
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
|
||||||
from synapse.types import ISynapseReactor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reactor: IReactorCore,
|
reactor: IReactorCore,
|
||||||
proxy_reactor: Optional[ISynapseReactor] = None,
|
proxy_reactor: Optional[IReactorCore] = None,
|
||||||
contextFactory: Optional[IPolicyForHTTPS] = None,
|
contextFactory: Optional[IPolicyForHTTPS] = None,
|
||||||
connectTimeout: Optional[float] = None,
|
connectTimeout: Optional[float] = None,
|
||||||
bindAddress: Optional[bytes] = None,
|
bindAddress: Optional[bytes] = None,
|
||||||
|
@ -19,13 +19,15 @@ from zope.interface import implementer
|
|||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import Connection
|
from OpenSSL.SSL import Connection
|
||||||
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||||
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOProtocol
|
||||||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def get_test_https_policy():
|
def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
|
||||||
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -39,7 +41,7 @@ def get_test_https_policy():
|
|||||||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||||
|
|
||||||
|
|
||||||
def get_test_ca_cert_file():
|
def get_test_ca_cert_file() -> str:
|
||||||
"""Get the path to the test CA cert
|
"""Get the path to the test CA cert
|
||||||
|
|
||||||
The keypair is generated with:
|
The keypair is generated with:
|
||||||
@ -51,7 +53,7 @@ def get_test_ca_cert_file():
|
|||||||
return os.path.join(os.path.dirname(__file__), "ca.crt")
|
return os.path.join(os.path.dirname(__file__), "ca.crt")
|
||||||
|
|
||||||
|
|
||||||
def get_test_key_file():
|
def get_test_key_file() -> str:
|
||||||
"""get the path to the test key
|
"""get the path to the test key
|
||||||
|
|
||||||
The key file is made with:
|
The key file is made with:
|
||||||
@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory:
|
|||||||
"""An SSL connection creator which returns connections which present a certificate
|
"""An SSL connection creator which returns connections which present a certificate
|
||||||
signed by our test CA."""
|
signed by our test CA."""
|
||||||
|
|
||||||
def __init__(self, sanlist):
|
def __init__(self, sanlist: List[bytes]):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
sanlist: list[bytes]: a list of subjectAltName values for the cert
|
sanlist: a list of subjectAltName values for the cert
|
||||||
"""
|
"""
|
||||||
self._cert_file = create_test_cert_file(sanlist)
|
self._cert_file = create_test_cert_file(sanlist)
|
||||||
|
|
||||||
def serverConnectionForTLS(self, tlsProtocol):
|
def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
|
||||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
ctx.use_certificate_file(self._cert_file)
|
ctx.use_certificate_file(self._cert_file)
|
||||||
ctx.use_privatekey_file(get_test_key_file())
|
ctx.use_privatekey_file(get_test_key_file())
|
||||||
return Connection(ctx, None)
|
return Connection(ctx, None)
|
||||||
|
|
||||||
|
|
||||||
|
# A dummy address, useful for tests that use FakeTransport and don't care about where
|
||||||
|
# packets are going to/coming from.
|
||||||
|
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Iterable, Optional
|
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import treq
|
import treq
|
||||||
@ -24,14 +24,19 @@ from zope.interface import implementer
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
|
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
|
||||||
from twisted.internet.interfaces import IProtocolFactory
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.endpoints import _WrappingProtocol
|
||||||
|
from twisted.internet.interfaces import (
|
||||||
|
IOpenSSLClientConnectionCreator,
|
||||||
|
IProtocolFactory,
|
||||||
|
)
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
from twisted.web._newclient import ResponseNeverReceived
|
from twisted.web._newclient import ResponseNeverReceived
|
||||||
from twisted.web.client import Agent
|
from twisted.web.client import Agent
|
||||||
from twisted.web.http import HTTPChannel, Request
|
from twisted.web.http import HTTPChannel, Request
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IPolicyForHTTPS
|
from twisted.web.iweb import IPolicyForHTTPS, IResponse
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||||
@ -42,11 +47,21 @@ from synapse.http.federation.well_known_resolver import (
|
|||||||
WellKnownResolver,
|
WellKnownResolver,
|
||||||
_cache_period_from_headers,
|
_cache_period_from_headers,
|
||||||
)
|
)
|
||||||
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
|
from synapse.logging.context import (
|
||||||
|
SENTINEL_CONTEXT,
|
||||||
|
LoggingContext,
|
||||||
|
LoggingContextOrSentinel,
|
||||||
|
current_context,
|
||||||
|
)
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
from synapse.util.caches.ttlcache import TTLCache
|
from synapse.util.caches.ttlcache import TTLCache
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
|
from tests.http import (
|
||||||
|
TestServerTLSConnectionFactory,
|
||||||
|
dummy_address,
|
||||||
|
get_test_ca_cert_file,
|
||||||
|
)
|
||||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
@ -54,15 +69,17 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# Once Async Mocks or lambdas are supported this can go away.
|
# Once Async Mocks or lambdas are supported this can go away.
|
||||||
def generate_resolve_service(result):
|
def generate_resolve_service(
|
||||||
async def resolve_service(_):
|
result: List[Server],
|
||||||
|
) -> Callable[[Any], Awaitable[List[Server]]]:
|
||||||
|
async def resolve_service(_: Any) -> List[Server]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return resolve_service
|
return resolve_service
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationAgentTests(unittest.TestCase):
|
class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
self.mock_resolver = Mock()
|
self.mock_resolver = Mock()
|
||||||
@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.tls_factory = FederationPolicyForHTTPS(config)
|
self.tls_factory = FederationPolicyForHTTPS(config)
|
||||||
|
|
||||||
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
|
self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache(
|
||||||
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
|
"test_cache", timer=self.reactor.seconds
|
||||||
|
)
|
||||||
|
self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache(
|
||||||
|
"test_cache", timer=self.reactor.seconds
|
||||||
|
)
|
||||||
self.well_known_resolver = WellKnownResolver(
|
self.well_known_resolver = WellKnownResolver(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
Agent(self.reactor, contextFactory=self.tls_factory),
|
Agent(self.reactor, contextFactory=self.tls_factory),
|
||||||
@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self,
|
self,
|
||||||
client_factory: IProtocolFactory,
|
client_factory: IProtocolFactory,
|
||||||
ssl: bool = True,
|
ssl: bool = True,
|
||||||
expected_sni: bytes = None,
|
expected_sni: Optional[bytes] = None,
|
||||||
tls_sanlist: Optional[Iterable[bytes]] = None,
|
tls_sanlist: Optional[List[bytes]] = None,
|
||||||
) -> HTTPChannel:
|
) -> HTTPChannel:
|
||||||
"""Builds a test server, and completes the outgoing client connection
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
Args:
|
Args:
|
||||||
@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
if ssl:
|
if ssl:
|
||||||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
||||||
|
|
||||||
server_protocol = server_factory.buildProtocol(None)
|
server_protocol = server_factory.buildProtocol(dummy_address)
|
||||||
|
assert server_protocol is not None
|
||||||
# now, tell the client protocol factory to build the client protocol (it will be a
|
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||||
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||||
@ -125,7 +146,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
#
|
#
|
||||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
# stubbing that out here.
|
# stubbing that out here.
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(dummy_address)
|
||||||
|
assert isinstance(client_protocol, _WrappingProtocol)
|
||||||
client_protocol.makeConnection(
|
client_protocol.makeConnection(
|
||||||
FakeTransport(server_protocol, self.reactor, client_protocol)
|
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||||
)
|
)
|
||||||
@ -136,6 +158,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if ssl:
|
if ssl:
|
||||||
|
assert isinstance(server_protocol, TLSMemoryBIOProtocol)
|
||||||
# fish the test server back out of the server-side TLS protocol.
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
http_protocol = server_protocol.wrappedProtocol
|
http_protocol = server_protocol.wrappedProtocol
|
||||||
# grab a hold of the TLS connection, in case it gets torn down
|
# grab a hold of the TLS connection, in case it gets torn down
|
||||||
@ -144,6 +167,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
http_protocol = server_protocol
|
http_protocol = server_protocol
|
||||||
tls_connection = None
|
tls_connection = None
|
||||||
|
|
||||||
|
assert isinstance(http_protocol, HTTPChannel)
|
||||||
# give the reactor a pump to get the TLS juices flowing (if needed)
|
# give the reactor a pump to get the TLS juices flowing (if needed)
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
|
|
||||||
@ -159,12 +183,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
return http_protocol
|
return http_protocol
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_get_request(self, uri: bytes):
|
def _make_get_request(
|
||||||
|
self, uri: bytes
|
||||||
|
) -> Generator["Deferred[object]", object, IResponse]:
|
||||||
"""
|
"""
|
||||||
Sends a simple GET request via the agent, and checks its logcontext management
|
Sends a simple GET request via the agent, and checks its logcontext management
|
||||||
"""
|
"""
|
||||||
with LoggingContext("one") as context:
|
with LoggingContext("one") as context:
|
||||||
fetch_d = self.agent.request(b"GET", uri)
|
fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri)
|
||||||
|
|
||||||
# Nothing happened yet
|
# Nothing happened yet
|
||||||
self.assertNoResult(fetch_d)
|
self.assertNoResult(fetch_d)
|
||||||
@ -172,8 +198,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
# should have reset logcontext to the sentinel
|
# should have reset logcontext to the sentinel
|
||||||
_check_logcontext(SENTINEL_CONTEXT)
|
_check_logcontext(SENTINEL_CONTEXT)
|
||||||
|
|
||||||
|
fetch_res: IResponse
|
||||||
try:
|
try:
|
||||||
fetch_res = yield fetch_d
|
fetch_res = yield fetch_d # type: ignore[misc, assignment]
|
||||||
return fetch_res
|
return fetch_res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
|
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
|
||||||
@ -216,7 +243,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
request: Request,
|
request: Request,
|
||||||
content: bytes,
|
content: bytes,
|
||||||
headers: Optional[dict] = None,
|
headers: Optional[dict] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Check that an incoming request looks like a valid .well-known request, and
|
"""Check that an incoming request looks like a valid .well-known request, and
|
||||||
send back the response.
|
send back the response.
|
||||||
"""
|
"""
|
||||||
@ -237,16 +264,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
because it is created too early during setUp
|
because it is created too early during setUp
|
||||||
"""
|
"""
|
||||||
return MatrixFederationAgent(
|
return MatrixFederationAgent(
|
||||||
reactor=self.reactor,
|
reactor=cast(ISynapseReactor, self.reactor),
|
||||||
tls_client_options_factory=self.tls_factory,
|
tls_client_options_factory=self.tls_factory,
|
||||||
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
|
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
|
||||||
ip_whitelist=IPSet(),
|
ip_whitelist=IPSet(),
|
||||||
ip_blacklist=IPSet(),
|
ip_blacklist=IPSet(),
|
||||||
_srv_resolver=self.mock_resolver,
|
_srv_resolver=self.mock_resolver,
|
||||||
_well_known_resolver=self.well_known_resolver,
|
_well_known_resolver=self.well_known_resolver,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get(self):
|
def test_get(self) -> None:
|
||||||
"""happy-path test of a GET request with an explicit port"""
|
"""happy-path test of a GET request with an explicit port"""
|
||||||
self._do_get()
|
self._do_get()
|
||||||
|
|
||||||
@ -254,11 +281,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"https_proxy": "proxy.com", "no_proxy": "testserv"},
|
{"https_proxy": "proxy.com", "no_proxy": "testserv"},
|
||||||
)
|
)
|
||||||
def test_get_bypass_proxy(self):
|
def test_get_bypass_proxy(self) -> None:
|
||||||
"""test of a GET request with an explicit port and bypass proxy"""
|
"""test of a GET request with an explicit port and bypass proxy"""
|
||||||
self._do_get()
|
self._do_get()
|
||||||
|
|
||||||
def _do_get(self):
|
def _do_get(self) -> None:
|
||||||
"""test of a GET request with an explicit port"""
|
"""test of a GET request with an explicit port"""
|
||||||
self.agent = self._make_agent()
|
self.agent = self._make_agent()
|
||||||
|
|
||||||
@ -318,7 +345,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
|
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
|
||||||
)
|
)
|
||||||
def test_get_via_http_proxy(self):
|
def test_get_via_http_proxy(self) -> None:
|
||||||
"""test for federation request through a http proxy"""
|
"""test for federation request through a http proxy"""
|
||||||
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
|
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
|
||||||
|
|
||||||
@ -326,7 +353,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
|
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
|
||||||
)
|
)
|
||||||
def test_get_via_http_proxy_with_auth(self):
|
def test_get_via_http_proxy_with_auth(self) -> None:
|
||||||
"""test for federation request through a http proxy with authentication"""
|
"""test for federation request through a http proxy with authentication"""
|
||||||
self._do_get_via_proxy(
|
self._do_get_via_proxy(
|
||||||
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
|
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
|
||||||
@ -335,7 +362,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
|
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
|
||||||
)
|
)
|
||||||
def test_get_via_https_proxy(self):
|
def test_get_via_https_proxy(self) -> None:
|
||||||
"""test for federation request through a https proxy"""
|
"""test for federation request through a https proxy"""
|
||||||
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
|
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
|
||||||
|
|
||||||
@ -343,7 +370,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
|
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
|
||||||
)
|
)
|
||||||
def test_get_via_https_proxy_with_auth(self):
|
def test_get_via_https_proxy_with_auth(self) -> None:
|
||||||
"""test for federation request through a https proxy with authentication"""
|
"""test for federation request through a https proxy with authentication"""
|
||||||
self._do_get_via_proxy(
|
self._do_get_via_proxy(
|
||||||
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
|
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
|
||||||
@ -353,7 +380,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self,
|
self,
|
||||||
expect_proxy_ssl: bool = False,
|
expect_proxy_ssl: bool = False,
|
||||||
expected_auth_credentials: Optional[bytes] = None,
|
expected_auth_credentials: Optional[bytes] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Send a https federation request via an agent and check that it is correctly
|
"""Send a https federation request via an agent and check that it is correctly
|
||||||
received at the proxy and client. The proxy can use either http or https.
|
received at the proxy and client. The proxy can use either http or https.
|
||||||
Args:
|
Args:
|
||||||
@ -418,10 +445,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
# now we make another test server to act as the upstream HTTP server.
|
# now we make another test server to act as the upstream HTTP server.
|
||||||
server_ssl_protocol = _wrap_server_factory_for_tls(
|
server_ssl_protocol = _wrap_server_factory_for_tls(
|
||||||
_get_test_protocol_factory()
|
_get_test_protocol_factory()
|
||||||
).buildProtocol(None)
|
).buildProtocol(dummy_address)
|
||||||
|
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
|
||||||
|
|
||||||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
||||||
proxy_server_transport = proxy_server.transport
|
proxy_server_transport = proxy_server.transport
|
||||||
|
assert proxy_server_transport is not None
|
||||||
server_ssl_protocol.makeConnection(proxy_server_transport)
|
server_ssl_protocol.makeConnection(proxy_server_transport)
|
||||||
|
|
||||||
# ... and replace the protocol on the proxy's transport with the
|
# ... and replace the protocol on the proxy's transport with the
|
||||||
@ -451,6 +480,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
|
|
||||||
# now there should be a pending request
|
# now there should be a pending request
|
||||||
http_server = server_ssl_protocol.wrappedProtocol
|
http_server = server_ssl_protocol.wrappedProtocol
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
@ -491,7 +521,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
json = self.successResultOf(treq.json_content(response))
|
json = self.successResultOf(treq.json_content(response))
|
||||||
self.assertEqual(json, {"a": 1})
|
self.assertEqual(json, {"a": 1})
|
||||||
|
|
||||||
def test_get_ip_address(self):
|
def test_get_ip_address(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name contains an explicit IP (with no port)
|
Test the behaviour when the server name contains an explicit IP (with no port)
|
||||||
"""
|
"""
|
||||||
@ -526,7 +556,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_ipv6_address(self):
|
def test_get_ipv6_address(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name contains an explicit IPv6 address
|
Test the behaviour when the server name contains an explicit IPv6 address
|
||||||
(with no port)
|
(with no port)
|
||||||
@ -562,7 +592,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_ipv6_address_with_port(self):
|
def test_get_ipv6_address_with_port(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name contains an explicit IPv6 address
|
Test the behaviour when the server name contains an explicit IPv6 address
|
||||||
(with explicit port)
|
(with explicit port)
|
||||||
@ -598,7 +628,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_hostname_bad_cert(self):
|
def test_get_hostname_bad_cert(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the certificate on the server doesn't match the hostname
|
Test the behaviour when the certificate on the server doesn't match the hostname
|
||||||
"""
|
"""
|
||||||
@ -651,7 +681,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
failure_reason = e.value.reasons[0]
|
failure_reason = e.value.reasons[0]
|
||||||
self.assertIsInstance(failure_reason.value, VerificationError)
|
self.assertIsInstance(failure_reason.value, VerificationError)
|
||||||
|
|
||||||
def test_get_ip_address_bad_cert(self):
|
def test_get_ip_address_bad_cert(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name contains an explicit IP, but
|
Test the behaviour when the server name contains an explicit IP, but
|
||||||
the server cert doesn't cover it
|
the server cert doesn't cover it
|
||||||
@ -684,7 +714,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
failure_reason = e.value.reasons[0]
|
failure_reason = e.value.reasons[0]
|
||||||
self.assertIsInstance(failure_reason.value, VerificationError)
|
self.assertIsInstance(failure_reason.value, VerificationError)
|
||||||
|
|
||||||
def test_get_no_srv_no_well_known(self):
|
def test_get_no_srv_no_well_known(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name has no port, no SRV, and no well-known
|
Test the behaviour when the server name has no port, no SRV, and no well-known
|
||||||
"""
|
"""
|
||||||
@ -740,7 +770,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_well_known(self):
|
def test_get_well_known(self) -> None:
|
||||||
"""Test the behaviour when the .well-known delegates elsewhere"""
|
"""Test the behaviour when the .well-known delegates elsewhere"""
|
||||||
self.agent = self._make_agent()
|
self.agent = self._make_agent()
|
||||||
|
|
||||||
@ -802,7 +832,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.well_known_cache.expire()
|
self.well_known_cache.expire()
|
||||||
self.assertNotIn(b"testserv", self.well_known_cache)
|
self.assertNotIn(b"testserv", self.well_known_cache)
|
||||||
|
|
||||||
def test_get_well_known_redirect(self):
|
def test_get_well_known_redirect(self) -> None:
|
||||||
"""Test the behaviour when the server name has no port and no SRV record, but
|
"""Test the behaviour when the server name has no port and no SRV record, but
|
||||||
the .well-known has a 300 redirect
|
the .well-known has a 300 redirect
|
||||||
"""
|
"""
|
||||||
@ -892,7 +922,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.well_known_cache.expire()
|
self.well_known_cache.expire()
|
||||||
self.assertNotIn(b"testserv", self.well_known_cache)
|
self.assertNotIn(b"testserv", self.well_known_cache)
|
||||||
|
|
||||||
def test_get_invalid_well_known(self):
|
def test_get_invalid_well_known(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
|
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
|
||||||
"""
|
"""
|
||||||
@ -945,7 +975,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_well_known_unsigned_cert(self):
|
def test_get_well_known_unsigned_cert(self) -> None:
|
||||||
"""Test the behaviour when the .well-known server presents a cert
|
"""Test the behaviour when the .well-known server presents a cert
|
||||||
not signed by a CA
|
not signed by a CA
|
||||||
"""
|
"""
|
||||||
@ -969,7 +999,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
ip_blacklist=IPSet(),
|
ip_blacklist=IPSet(),
|
||||||
_srv_resolver=self.mock_resolver,
|
_srv_resolver=self.mock_resolver,
|
||||||
_well_known_resolver=WellKnownResolver(
|
_well_known_resolver=WellKnownResolver(
|
||||||
self.reactor,
|
cast(ISynapseReactor, self.reactor),
|
||||||
Agent(self.reactor, contextFactory=tls_factory),
|
Agent(self.reactor, contextFactory=tls_factory),
|
||||||
b"test-agent",
|
b"test-agent",
|
||||||
well_known_cache=self.well_known_cache,
|
well_known_cache=self.well_known_cache,
|
||||||
@ -999,7 +1029,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
b"_matrix._tcp.testserv"
|
b"_matrix._tcp.testserv"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_hostname_srv(self):
|
def test_get_hostname_srv(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the behaviour when there is a single SRV record
|
Test the behaviour when there is a single SRV record
|
||||||
"""
|
"""
|
||||||
@ -1041,7 +1071,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_get_well_known_srv(self):
|
def test_get_well_known_srv(self) -> None:
|
||||||
"""Test the behaviour when the .well-known redirects to a place where there
|
"""Test the behaviour when the .well-known redirects to a place where there
|
||||||
is a SRV.
|
is a SRV.
|
||||||
"""
|
"""
|
||||||
@ -1101,7 +1131,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_idna_servername(self):
|
def test_idna_servername(self) -> None:
|
||||||
"""test the behaviour when the server name has idna chars in"""
|
"""test the behaviour when the server name has idna chars in"""
|
||||||
self.agent = self._make_agent()
|
self.agent = self._make_agent()
|
||||||
|
|
||||||
@ -1163,7 +1193,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_idna_srv_target(self):
|
def test_idna_srv_target(self) -> None:
|
||||||
"""test the behaviour when the target of a SRV record has idna chars"""
|
"""test the behaviour when the target of a SRV record has idna chars"""
|
||||||
self.agent = self._make_agent()
|
self.agent = self._make_agent()
|
||||||
|
|
||||||
@ -1206,7 +1236,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
self.successResultOf(test_d)
|
self.successResultOf(test_d)
|
||||||
|
|
||||||
def test_well_known_cache(self):
|
def test_well_known_cache(self) -> None:
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
fetch_d = defer.ensureDeferred(
|
fetch_d = defer.ensureDeferred(
|
||||||
@ -1262,7 +1292,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
r = self.successResultOf(fetch_d)
|
r = self.successResultOf(fetch_d)
|
||||||
self.assertEqual(r.delegated_server, b"other-server")
|
self.assertEqual(r.delegated_server, b"other-server")
|
||||||
|
|
||||||
def test_well_known_cache_with_temp_failure(self):
|
def test_well_known_cache_with_temp_failure(self) -> None:
|
||||||
"""Test that we refetch well-known before the cache expires, and that
|
"""Test that we refetch well-known before the cache expires, and that
|
||||||
it ignores transient errors.
|
it ignores transient errors.
|
||||||
"""
|
"""
|
||||||
@ -1341,7 +1371,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
r = self.successResultOf(fetch_d)
|
r = self.successResultOf(fetch_d)
|
||||||
self.assertEqual(r.delegated_server, None)
|
self.assertEqual(r.delegated_server, None)
|
||||||
|
|
||||||
def test_well_known_too_large(self):
|
def test_well_known_too_large(self) -> None:
|
||||||
"""A well-known query that returns a result which is too large should be rejected."""
|
"""A well-known query that returns a result which is too large should be rejected."""
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
@ -1367,7 +1397,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
r = self.successResultOf(fetch_d)
|
r = self.successResultOf(fetch_d)
|
||||||
self.assertIsNone(r.delegated_server)
|
self.assertIsNone(r.delegated_server)
|
||||||
|
|
||||||
def test_srv_fallbacks(self):
|
def test_srv_fallbacks(self) -> None:
|
||||||
"""Test that other SRV results are tried if the first one fails."""
|
"""Test that other SRV results are tried if the first one fails."""
|
||||||
self.agent = self._make_agent()
|
self.agent = self._make_agent()
|
||||||
|
|
||||||
@ -1427,7 +1457,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestCachePeriodFromHeaders(unittest.TestCase):
|
class TestCachePeriodFromHeaders(unittest.TestCase):
|
||||||
def test_cache_control(self):
|
def test_cache_control(self) -> None:
|
||||||
# uppercase
|
# uppercase
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
@ -1464,7 +1494,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_expires(self):
|
def test_expires(self) -> None:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
|
Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
|
||||||
@ -1491,14 +1521,14 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
|
|||||||
self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
|
self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
|
||||||
|
|
||||||
|
|
||||||
def _check_logcontext(context):
|
def _check_logcontext(context: LoggingContextOrSentinel) -> None:
|
||||||
current = current_context()
|
current = current_context()
|
||||||
if current is not context:
|
if current is not context:
|
||||||
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
||||||
|
|
||||||
|
|
||||||
def _wrap_server_factory_for_tls(
|
def _wrap_server_factory_for_tls(
|
||||||
factory: IProtocolFactory, sanlist: Iterable[bytes] = None
|
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
|
||||||
) -> IProtocolFactory:
|
) -> IProtocolFactory:
|
||||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||||
The resultant factory will create a TLS server which presents a certificate
|
The resultant factory will create a TLS server which presents a certificate
|
||||||
@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory:
|
|||||||
return server_factory
|
return server_factory
|
||||||
|
|
||||||
|
|
||||||
def _log_request(request: str):
|
def _log_request(request: str) -> None:
|
||||||
"""Implements Factory.log, which is expected by Request.finish"""
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
logger.info(f"Completed request {request}")
|
logger.info(f"Completed request {request}")
|
||||||
|
|
||||||
@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS:
|
|||||||
"""An IPolicyForHTTPS which checks that the certificate belongs to the
|
"""An IPolicyForHTTPS which checks that the certificate belongs to the
|
||||||
right server, but doesn't check the certificate chain."""
|
right server, but doesn't check the certificate chain."""
|
||||||
|
|
||||||
def creatorForNetloc(self, hostname, port):
|
def creatorForNetloc(
|
||||||
|
self, hostname: bytes, port: int
|
||||||
|
) -> IOpenSSLClientConnectionCreator:
|
||||||
certificateOptions = OpenSSLCertificateOptions()
|
certificateOptions = OpenSSLCertificateOptions()
|
||||||
return ClientTLSOptions(hostname, certificateOptions.getContext())
|
return ClientTLSOptions(hostname, certificateOptions.getContext())
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Dict, Generator, List, Tuple, cast
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred
|
|||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.names import dns, error
|
from twisted.names import dns, error
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import SrvResolver
|
from synapse.http.federation.srv_resolver import Server, SrvResolver
|
||||||
from synapse.logging.context import LoggingContext, current_context
|
from synapse.logging.context import LoggingContext, current_context
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -28,7 +28,7 @@ from tests.utils import MockClock
|
|||||||
|
|
||||||
|
|
||||||
class SrvResolverTestCase(unittest.TestCase):
|
class SrvResolverTestCase(unittest.TestCase):
|
||||||
def test_resolve(self):
|
def test_resolve(self) -> None:
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
|
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
|
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
result_deferred = Deferred()
|
result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
|
||||||
dns_client_mock.lookupService.return_value = result_deferred
|
dns_client_mock.lookupService.return_value = result_deferred
|
||||||
|
|
||||||
cache = {}
|
cache: Dict[bytes, List[Server]] = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_lookup():
|
def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
|
||||||
|
|
||||||
with LoggingContext("one") as ctx:
|
with LoggingContext("one") as ctx:
|
||||||
resolve_d = resolver.resolve_service(service_name)
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
result = yield defer.ensureDeferred(resolve_d)
|
result: List[Server]
|
||||||
|
result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment]
|
||||||
|
|
||||||
# should have restored our context
|
# should have restored our context
|
||||||
self.assertIs(current_context(), ctx)
|
self.assertIs(current_context(), ctx)
|
||||||
@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(servers[0].host, host_name)
|
self.assertEqual(servers[0].host, host_name)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_from_cache_expired_and_dns_fail(self):
|
def test_from_cache_expired_and_dns_fail(
|
||||||
|
self,
|
||||||
|
) -> Generator["Deferred[object]", object, None]:
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
entry.priority = 0
|
entry.priority = 0
|
||||||
entry.weight = 0
|
entry.weight = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [cast(Server, entry)]}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
servers: List[Server]
|
||||||
|
servers = yield defer.ensureDeferred(
|
||||||
|
resolver.resolve_service(service_name)
|
||||||
|
) # type: ignore[assignment]
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(servers, cache[service_name])
|
self.assertEqual(servers, cache[service_name])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_from_cache(self):
|
def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
|
||||||
clock = MockClock()
|
clock = MockClock()
|
||||||
|
|
||||||
dns_client_mock = Mock(spec_set=["lookupService"])
|
dns_client_mock = Mock(spec_set=["lookupService"])
|
||||||
@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
entry.priority = 0
|
entry.priority = 0
|
||||||
entry.weight = 0
|
entry.weight = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [cast(Server, entry)]}
|
||||||
resolver = SrvResolver(
|
resolver = SrvResolver(
|
||||||
dns_client=dns_client_mock, cache=cache, get_time=clock.time
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time
|
||||||
)
|
)
|
||||||
|
|
||||||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
servers: List[Server]
|
||||||
|
servers = yield defer.ensureDeferred(
|
||||||
|
resolver.resolve_service(service_name)
|
||||||
|
) # type: ignore[assignment]
|
||||||
|
|
||||||
self.assertFalse(dns_client_mock.lookupService.called)
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(servers, cache[service_name])
|
self.assertEqual(servers, cache[service_name])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_empty_cache(self):
|
def test_empty_cache(self) -> Generator["Deferred[object]", object, None]:
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||||
|
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache: Dict[bytes, List[Server]] = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
with self.assertRaises(error.DNSServerError):
|
with self.assertRaises(error.DNSServerError):
|
||||||
yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_name_error(self):
|
def test_name_error(self) -> Generator["Deferred[object]", object, None]:
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
|
|
||||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||||
|
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache: Dict[bytes, List[Server]] = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
|
servers: List[Server]
|
||||||
|
servers = yield defer.ensureDeferred(
|
||||||
|
resolver.resolve_service(service_name)
|
||||||
|
) # type: ignore[assignment]
|
||||||
|
|
||||||
self.assertEqual(len(servers), 0)
|
self.assertEqual(len(servers), 0)
|
||||||
self.assertEqual(len(cache), 0)
|
self.assertEqual(len(cache), 0)
|
||||||
|
|
||||||
def test_disabled_service(self):
|
def test_disabled_service(self) -> None:
|
||||||
"""
|
"""
|
||||||
test the behaviour when there is a single record which is ".".
|
test the behaviour when there is a single record which is ".".
|
||||||
"""
|
"""
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
lookup_deferred = Deferred()
|
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache: Dict[bytes, List[Server]] = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
|
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
|
||||||
@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.failureResultOf(resolve_d, ConnectError)
|
self.failureResultOf(resolve_d, ConnectError)
|
||||||
|
|
||||||
def test_non_srv_answer(self):
|
def test_non_srv_answer(self) -> None:
|
||||||
"""
|
"""
|
||||||
test the behaviour when the dns server gives us a spurious non-SRV response
|
test the behaviour when the dns server gives us a spurious non-SRV response
|
||||||
"""
|
"""
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
lookup_deferred = Deferred()
|
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache: Dict[bytes, List[Server]] = {}
|
||||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
# Old versions of Twisted don't have an ensureDeferred in successResultOf.
|
# Old versions of Twisted don't have an ensureDeferred in successResultOf.
|
||||||
|
@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
|
|||||||
return method_name
|
return method_name
|
||||||
|
|
||||||
|
|
||||||
def _hash_stack(stack: List[inspect.FrameInfo]):
|
def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]:
|
||||||
"""Turns a stack into a hashable value that can be put into a set."""
|
"""Turns a stack into a hashable value that can be put into a set."""
|
||||||
return tuple(_format_stack_frame(frame) for frame in stack)
|
return tuple(_format_stack_frame(frame) for frame in stack)
|
||||||
|
@ -11,28 +11,34 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.additional_resource import AdditionalResource
|
from synapse.http.additional_resource import AdditionalResource
|
||||||
from synapse.http.server import respond_with_json
|
from synapse.http.server import respond_with_json
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests.server import FakeSite, make_request
|
from tests.server import FakeSite, make_request
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class _AsyncTestCustomEndpoint:
|
class _AsyncTestCustomEndpoint:
|
||||||
def __init__(self, config, module_api):
|
def __init__(self, config: JsonDict, module_api: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def handle_request(self, request):
|
async def handle_request(self, request: Request) -> None:
|
||||||
|
assert isinstance(request, SynapseRequest)
|
||||||
respond_with_json(request, 200, {"some_key": "some_value_async"})
|
respond_with_json(request, 200, {"some_key": "some_value_async"})
|
||||||
|
|
||||||
|
|
||||||
class _SyncTestCustomEndpoint:
|
class _SyncTestCustomEndpoint:
|
||||||
def __init__(self, config, module_api):
|
def __init__(self, config: JsonDict, module_api: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def handle_request(self, request):
|
async def handle_request(self, request: Request) -> None:
|
||||||
|
assert isinstance(request, SynapseRequest)
|
||||||
respond_with_json(request, 200, {"some_key": "some_value_sync"})
|
respond_with_json(request, 200, {"some_key": "some_value_sync"})
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase):
|
|||||||
and async handlers.
|
and async handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_async(self):
|
def test_async(self) -> None:
|
||||||
handler = _AsyncTestCustomEndpoint({}, None).handle_request
|
handler = _AsyncTestCustomEndpoint({}, None).handle_request
|
||||||
resource = AdditionalResource(self.hs, handler)
|
resource = AdditionalResource(self.hs, handler)
|
||||||
|
|
||||||
@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase):
|
|||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
|
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
|
||||||
|
|
||||||
def test_sync(self):
|
def test_sync(self) -> None:
|
||||||
handler = _SyncTestCustomEndpoint({}, None).handle_request
|
handler = _SyncTestCustomEndpoint({}, None).handle_request
|
||||||
resource = AdditionalResource(self.hs, handler)
|
resource = AdditionalResource(self.hs, handler)
|
||||||
|
|
||||||
|
@ -13,10 +13,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from typing import Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||||
@ -28,6 +30,7 @@ from synapse.http.client import (
|
|||||||
BlacklistingAgentWrapper,
|
BlacklistingAgentWrapper,
|
||||||
BlacklistingReactorWrapper,
|
BlacklistingReactorWrapper,
|
||||||
BodyExceededMaxSize,
|
BodyExceededMaxSize,
|
||||||
|
_DiscardBodyWithMaxSizeProtocol,
|
||||||
read_body_with_max_size,
|
read_body_with_max_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,7 +39,9 @@ from tests.unittest import TestCase
|
|||||||
|
|
||||||
|
|
||||||
class ReadBodyWithMaxSizeTests(TestCase):
|
class ReadBodyWithMaxSizeTests(TestCase):
|
||||||
def _build_response(self, length=UNKNOWN_LENGTH):
|
def _build_response(
|
||||||
|
self, length: Union[int, str] = UNKNOWN_LENGTH
|
||||||
|
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
|
||||||
"""Start reading the body, returns the response, result and proto"""
|
"""Start reading the body, returns the response, result and proto"""
|
||||||
response = Mock(length=length)
|
response = Mock(length=length)
|
||||||
result = BytesIO()
|
result = BytesIO()
|
||||||
@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
|
|
||||||
return result, deferred, protocol
|
return result, deferred, protocol
|
||||||
|
|
||||||
def _assert_error(self, deferred, protocol):
|
def _assert_error(
|
||||||
|
self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
|
||||||
|
) -> None:
|
||||||
"""Ensure that the expected error is received."""
|
"""Ensure that the expected error is received."""
|
||||||
self.assertIsInstance(deferred.result, Failure)
|
assert isinstance(deferred.result, Failure)
|
||||||
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
|
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
|
||||||
protocol.transport.abortConnection.assert_called_once()
|
assert protocol.transport is not None
|
||||||
|
# type-ignore: presumably abortConnection has been replaced with a Mock.
|
||||||
|
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
|
||||||
|
|
||||||
def _cleanup_error(self, deferred):
|
def _cleanup_error(self, deferred: "Deferred[int]") -> None:
|
||||||
"""Ensure that the error in the Deferred is handled gracefully."""
|
"""Ensure that the error in the Deferred is handled gracefully."""
|
||||||
called = [False]
|
called = [False]
|
||||||
|
|
||||||
def errback(f):
|
def errback(f: Failure) -> None:
|
||||||
called[0] = True
|
called[0] = True
|
||||||
|
|
||||||
deferred.addErrback(errback)
|
deferred.addErrback(errback)
|
||||||
self.assertTrue(called[0])
|
self.assertTrue(called[0])
|
||||||
|
|
||||||
def test_no_error(self):
|
def test_no_error(self) -> None:
|
||||||
"""A response that is NOT too large."""
|
"""A response that is NOT too large."""
|
||||||
result, deferred, protocol = self._build_response()
|
result, deferred, protocol = self._build_response()
|
||||||
|
|
||||||
@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
self.assertEqual(result.getvalue(), b"12345")
|
self.assertEqual(result.getvalue(), b"12345")
|
||||||
self.assertEqual(deferred.result, 5)
|
self.assertEqual(deferred.result, 5)
|
||||||
|
|
||||||
def test_too_large(self):
|
def test_too_large(self) -> None:
|
||||||
"""A response which is too large raises an exception."""
|
"""A response which is too large raises an exception."""
|
||||||
result, deferred, protocol = self._build_response()
|
result, deferred, protocol = self._build_response()
|
||||||
|
|
||||||
@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
self._assert_error(deferred, protocol)
|
self._assert_error(deferred, protocol)
|
||||||
self._cleanup_error(deferred)
|
self._cleanup_error(deferred)
|
||||||
|
|
||||||
def test_multiple_packets(self):
|
def test_multiple_packets(self) -> None:
|
||||||
"""Data should be accumulated through mutliple packets."""
|
"""Data should be accumulated through mutliple packets."""
|
||||||
result, deferred, protocol = self._build_response()
|
result, deferred, protocol = self._build_response()
|
||||||
|
|
||||||
@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
self.assertEqual(result.getvalue(), b"1234")
|
self.assertEqual(result.getvalue(), b"1234")
|
||||||
self.assertEqual(deferred.result, 4)
|
self.assertEqual(deferred.result, 4)
|
||||||
|
|
||||||
def test_additional_data(self):
|
def test_additional_data(self) -> None:
|
||||||
"""A connection can receive data after being closed."""
|
"""A connection can receive data after being closed."""
|
||||||
result, deferred, protocol = self._build_response()
|
result, deferred, protocol = self._build_response()
|
||||||
|
|
||||||
@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
self._assert_error(deferred, protocol)
|
self._assert_error(deferred, protocol)
|
||||||
self._cleanup_error(deferred)
|
self._cleanup_error(deferred)
|
||||||
|
|
||||||
def test_content_length(self):
|
def test_content_length(self) -> None:
|
||||||
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
|
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
|
||||||
result, deferred, protocol = self._build_response(length=10)
|
result, deferred, protocol = self._build_response(length=10)
|
||||||
|
|
||||||
@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class BlacklistingAgentTest(TestCase):
|
class BlacklistingAgentTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.reactor, self.clock = get_clock()
|
self.reactor, self.clock = get_clock()
|
||||||
|
|
||||||
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
||||||
@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
|
|||||||
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
|
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
|
||||||
self.ip_blacklist = IPSet(["5.0.0.0/8"])
|
self.ip_blacklist = IPSet(["5.0.0.0/8"])
|
||||||
|
|
||||||
def test_reactor(self):
|
def test_reactor(self) -> None:
|
||||||
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
|
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
BlacklistingReactorWrapper(
|
BlacklistingReactorWrapper(
|
||||||
@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase):
|
|||||||
response = self.successResultOf(d)
|
response = self.successResultOf(d)
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
def test_agent(self):
|
def test_agent(self) -> None:
|
||||||
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
|
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
|
||||||
agent = BlacklistingAgentWrapper(
|
agent = BlacklistingAgentWrapper(
|
||||||
Agent(self.reactor),
|
Agent(self.reactor),
|
||||||
|
@ -17,7 +17,7 @@ from tests import unittest
|
|||||||
|
|
||||||
|
|
||||||
class ServerNameTestCase(unittest.TestCase):
|
class ServerNameTestCase(unittest.TestCase):
|
||||||
def test_parse_server_name(self):
|
def test_parse_server_name(self) -> None:
|
||||||
test_data = {
|
test_data = {
|
||||||
"localhost": ("localhost", None),
|
"localhost": ("localhost", None),
|
||||||
"my-example.com:1234": ("my-example.com", 1234),
|
"my-example.com:1234": ("my-example.com", 1234),
|
||||||
@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase):
|
|||||||
for i, o in test_data.items():
|
for i, o in test_data.items():
|
||||||
self.assertEqual(parse_server_name(i), o)
|
self.assertEqual(parse_server_name(i), o)
|
||||||
|
|
||||||
def test_validate_bad_server_names(self):
|
def test_validate_bad_server_names(self) -> None:
|
||||||
test_data = [
|
test_data = [
|
||||||
"", # empty
|
"", # empty
|
||||||
"localhost:http", # non-numeric port
|
"localhost:http", # non-numeric port
|
||||||
|
@ -11,16 +11,16 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Generator
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import TimeoutError
|
from twisted.internet.defer import Deferred, TimeoutError
|
||||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import MemoryReactor, StringTransport
|
||||||
from twisted.web.client import ResponseNeverReceived
|
from twisted.web.client import ResponseNeverReceived
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import (
|
|||||||
MatrixFederationHttpClient,
|
MatrixFederationHttpClient,
|
||||||
MatrixFederationRequest,
|
MatrixFederationRequest,
|
||||||
)
|
)
|
||||||
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
|
from synapse.logging.context import (
|
||||||
|
SENTINEL_CONTEXT,
|
||||||
|
LoggingContext,
|
||||||
|
LoggingContextOrSentinel,
|
||||||
|
current_context,
|
||||||
|
)
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
def check_logcontext(context):
|
def check_logcontext(context: LoggingContextOrSentinel) -> None:
|
||||||
current = current_context()
|
current = current_context()
|
||||||
if current is not context:
|
if current is not context:
|
||||||
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
||||||
|
|
||||||
|
|
||||||
class FederationClientTests(HomeserverTestCase):
|
class FederationClientTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
|
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
self.cl = MatrixFederationHttpClient(self.hs, None)
|
self.cl = MatrixFederationHttpClient(self.hs, None)
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
def test_client_get(self):
|
def test_client_get(self) -> None:
|
||||||
"""
|
"""
|
||||||
happy-path test of a GET request
|
happy-path test of a GET request
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_request():
|
def do_request() -> Generator["Deferred[object]", object, object]:
|
||||||
with LoggingContext("one") as context:
|
with LoggingContext("one") as context:
|
||||||
fetch_d = defer.ensureDeferred(
|
fetch_d = defer.ensureDeferred(
|
||||||
self.cl.get_json("testserv:8008", "foo/bar")
|
self.cl.get_json("testserv:8008", "foo/bar")
|
||||||
@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
# check the response is as expected
|
# check the response is as expected
|
||||||
self.assertEqual(res, {"a": 1})
|
self.assertEqual(res, {"a": 1})
|
||||||
|
|
||||||
def test_dns_error(self):
|
def test_dns_error(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the DNS lookup returns an error, it will bubble up.
|
If the DNS lookup returns an error, it will bubble up.
|
||||||
"""
|
"""
|
||||||
@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
||||||
|
|
||||||
def test_client_connection_refused(self):
|
def test_client_connection_refused(self) -> None:
|
||||||
d = defer.ensureDeferred(
|
d = defer.ensureDeferred(
|
||||||
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
||||||
)
|
)
|
||||||
@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIs(f.value.inner_exception, e)
|
self.assertIs(f.value.inner_exception, e)
|
||||||
|
|
||||||
def test_client_never_connect(self):
|
def test_client_never_connect(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the HTTP request is not connected and is timed out, it'll give a
|
If the HTTP request is not connected and is timed out, it'll give a
|
||||||
ConnectingCancelledError or TimeoutError.
|
ConnectingCancelledError or TimeoutError.
|
||||||
@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
|
f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_client_connect_no_response(self):
|
def test_client_connect_no_response(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the HTTP request is connected, but gets no response before being
|
If the HTTP request is connected, but gets no response before being
|
||||||
timed out, it'll give a ResponseNeverReceived.
|
timed out, it'll give a ResponseNeverReceived.
|
||||||
@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
|
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
|
||||||
|
|
||||||
def test_client_ip_range_blacklist(self):
|
def test_client_ip_range_blacklist(self) -> None:
|
||||||
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
|
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
|
||||||
|
|
||||||
# Set up the ip_range blacklist
|
# Set up the ip_range blacklist
|
||||||
@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
f = self.failureResultOf(d, RequestSendFailed)
|
f = self.failureResultOf(d, RequestSendFailed)
|
||||||
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
|
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
|
||||||
|
|
||||||
def test_client_gets_headers(self):
|
def test_client_gets_headers(self) -> None:
|
||||||
"""
|
"""
|
||||||
Once the client gets the headers, _request returns successfully.
|
Once the client gets the headers, _request returns successfully.
|
||||||
"""
|
"""
|
||||||
@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertEqual(r.code, 200)
|
self.assertEqual(r.code, 200)
|
||||||
|
|
||||||
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
|
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
|
||||||
def test_timeout_reading_body(self, method_name: str):
|
def test_timeout_reading_body(self, method_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
If the HTTP request is connected, but gets no response before being
|
If the HTTP request is connected, but gets no response before being
|
||||||
timed out, it'll give a RequestSendFailed with can_retry.
|
timed out, it'll give a RequestSendFailed with can_retry.
|
||||||
@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertTrue(f.value.can_retry)
|
self.assertTrue(f.value.can_retry)
|
||||||
self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
|
self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
|
||||||
|
|
||||||
def test_client_requires_trailing_slashes(self):
|
def test_client_requires_trailing_slashes(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a connection is made to a client but the client rejects it due to
|
If a connection is made to a client but the client rejects it due to
|
||||||
requiring a trailing slash. We need to retry the request with a
|
requiring a trailing slash. We need to retry the request with a
|
||||||
@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
r = self.successResultOf(d)
|
r = self.successResultOf(d)
|
||||||
self.assertEqual(r, {})
|
self.assertEqual(r, {})
|
||||||
|
|
||||||
def test_client_does_not_retry_on_400_plus(self):
|
def test_client_does_not_retry_on_400_plus(self) -> None:
|
||||||
"""
|
"""
|
||||||
Another test for trailing slashes but now test that we don't retry on
|
Another test for trailing slashes but now test that we don't retry on
|
||||||
trailing slashes on a non-400/M_UNRECOGNIZED response.
|
trailing slashes on a non-400/M_UNRECOGNIZED response.
|
||||||
@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
# We should get a 404 failure response
|
# We should get a 404 failure response
|
||||||
self.failureResultOf(d)
|
self.failureResultOf(d)
|
||||||
|
|
||||||
def test_client_sends_body(self):
|
def test_client_sends_body(self) -> None:
|
||||||
defer.ensureDeferred(
|
defer.ensureDeferred(
|
||||||
self.cl.post_json(
|
self.cl.post_json(
|
||||||
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
|
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
|
||||||
@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
content = request.content.read()
|
content = request.content.read()
|
||||||
self.assertEqual(content, b'{"a":"b"}')
|
self.assertEqual(content, b'{"a":"b"}')
|
||||||
|
|
||||||
def test_closes_connection(self):
|
def test_closes_connection(self) -> None:
|
||||||
"""Check that the client closes unused HTTP connections"""
|
"""Check that the client closes unused HTTP connections"""
|
||||||
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
|
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
|
||||||
|
|
||||||
@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
self.assertTrue(conn.disconnecting)
|
self.assertTrue(conn.disconnecting)
|
||||||
|
|
||||||
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
|
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
|
||||||
def test_json_error(self, return_value):
|
def test_json_error(self, return_value: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Test what happens if invalid JSON is returned from the remote endpoint.
|
Test what happens if invalid JSON is returned from the remote endpoint.
|
||||||
"""
|
"""
|
||||||
@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||||||
f = self.failureResultOf(test_d)
|
f = self.failureResultOf(test_d)
|
||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
|
|
||||||
def test_too_big(self):
|
def test_too_big(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test what happens if a huge response is returned from the remote endpoint.
|
Test what happens if a huge response is returned from the remote endpoint.
|
||||||
"""
|
"""
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Iterable, Optional
|
from typing import List, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import treq
|
import treq
|
||||||
@ -22,7 +22,11 @@ from netaddr import IPSet
|
|||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from twisted.internet import interfaces # noqa: F401
|
from twisted.internet import interfaces # noqa: F401
|
||||||
from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
|
from twisted.internet.endpoints import (
|
||||||
|
HostnameEndpoint,
|
||||||
|
_WrapperEndpoint,
|
||||||
|
_WrappingProtocol,
|
||||||
|
)
|
||||||
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
from twisted.internet.interfaces import IProtocol, IProtocolFactory
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
|
||||||
@ -32,7 +36,11 @@ from synapse.http.client import BlacklistingReactorWrapper
|
|||||||
from synapse.http.connectproxyclient import ProxyCredentials
|
from synapse.http.connectproxyclient import ProxyCredentials
|
||||||
from synapse.http.proxyagent import ProxyAgent, parse_proxy
|
from synapse.http.proxyagent import ProxyAgent, parse_proxy
|
||||||
|
|
||||||
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
|
from tests.http import (
|
||||||
|
TestServerTLSConnectionFactory,
|
||||||
|
dummy_address,
|
||||||
|
get_test_https_policy,
|
||||||
|
)
|
||||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
@ -183,7 +191,7 @@ class ProxyParserTests(TestCase):
|
|||||||
expected_hostname: bytes,
|
expected_hostname: bytes,
|
||||||
expected_port: int,
|
expected_port: int,
|
||||||
expected_credentials: Optional[bytes],
|
expected_credentials: Optional[bytes],
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that a given proxy URL will be broken into the components.
|
Tests that a given proxy URL will be broken into the components.
|
||||||
Args:
|
Args:
|
||||||
@ -209,7 +217,7 @@ class ProxyParserTests(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class MatrixFederationAgentTests(TestCase):
|
class MatrixFederationAgentTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
def _make_connection(
|
def _make_connection(
|
||||||
@ -218,7 +226,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
server_factory: IProtocolFactory,
|
server_factory: IProtocolFactory,
|
||||||
ssl: bool = False,
|
ssl: bool = False,
|
||||||
expected_sni: Optional[bytes] = None,
|
expected_sni: Optional[bytes] = None,
|
||||||
tls_sanlist: Optional[Iterable[bytes]] = None,
|
tls_sanlist: Optional[List[bytes]] = None,
|
||||||
) -> IProtocol:
|
) -> IProtocol:
|
||||||
"""Builds a test server, and completes the outgoing client connection
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
@ -244,7 +252,8 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
if ssl:
|
if ssl:
|
||||||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
|
||||||
|
|
||||||
server_protocol = server_factory.buildProtocol(None)
|
server_protocol = server_factory.buildProtocol(dummy_address)
|
||||||
|
assert server_protocol is not None
|
||||||
|
|
||||||
# now, tell the client protocol factory to build the client protocol,
|
# now, tell the client protocol factory to build the client protocol,
|
||||||
# and wire the output of said protocol up to the server via
|
# and wire the output of said protocol up to the server via
|
||||||
@ -252,7 +261,8 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
#
|
#
|
||||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
# stubbing that out here.
|
# stubbing that out here.
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(dummy_address)
|
||||||
|
assert client_protocol is not None
|
||||||
client_protocol.makeConnection(
|
client_protocol.makeConnection(
|
||||||
FakeTransport(server_protocol, self.reactor, client_protocol)
|
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||||
)
|
)
|
||||||
@ -263,6 +273,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if ssl:
|
if ssl:
|
||||||
|
assert isinstance(server_protocol, TLSMemoryBIOProtocol)
|
||||||
http_protocol = server_protocol.wrappedProtocol
|
http_protocol = server_protocol.wrappedProtocol
|
||||||
tls_connection = server_protocol._tlsConnection
|
tls_connection = server_protocol._tlsConnection
|
||||||
else:
|
else:
|
||||||
@ -288,7 +299,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
scheme: bytes,
|
scheme: bytes,
|
||||||
hostname: bytes,
|
hostname: bytes,
|
||||||
path: bytes,
|
path: bytes,
|
||||||
):
|
) -> None:
|
||||||
"""Runs a test case for a direct connection not going through a proxy.
|
"""Runs a test case for a direct connection not going through a proxy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -319,6 +330,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
ssl=is_https,
|
ssl=is_https,
|
||||||
expected_sni=hostname if is_https else None,
|
expected_sni=hostname if is_https else None,
|
||||||
)
|
)
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
|
|
||||||
# the FakeTransport is async, so we need to pump the reactor
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
@ -339,34 +351,34 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
body = self.successResultOf(treq.content(resp))
|
body = self.successResultOf(treq.content(resp))
|
||||||
self.assertEqual(body, b"result")
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
def test_http_request(self):
|
def test_http_request(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor)
|
agent = ProxyAgent(self.reactor)
|
||||||
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
||||||
|
|
||||||
def test_https_request(self):
|
def test_https_request(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
||||||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
||||||
|
|
||||||
def test_http_request_use_proxy_empty_environment(self):
|
def test_http_request_use_proxy_empty_environment(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor, use_proxy=True)
|
agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
|
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
|
||||||
def test_http_request_via_uppercase_no_proxy(self):
|
def test_http_request_via_uppercase_no_proxy(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor, use_proxy=True)
|
agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
||||||
|
|
||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
|
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
|
||||||
)
|
)
|
||||||
def test_http_request_via_no_proxy(self):
|
def test_http_request_via_no_proxy(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor, use_proxy=True)
|
agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
||||||
|
|
||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
|
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
|
||||||
)
|
)
|
||||||
def test_https_request_via_no_proxy(self):
|
def test_https_request_via_no_proxy(self) -> None:
|
||||||
agent = ProxyAgent(
|
agent = ProxyAgent(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
contextFactory=get_test_https_policy(),
|
contextFactory=get_test_https_policy(),
|
||||||
@ -375,12 +387,12 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
|
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
|
||||||
def test_http_request_via_no_proxy_star(self):
|
def test_http_request_via_no_proxy_star(self) -> None:
|
||||||
agent = ProxyAgent(self.reactor, use_proxy=True)
|
agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
|
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
|
||||||
def test_https_request_via_no_proxy_star(self):
|
def test_https_request_via_no_proxy_star(self) -> None:
|
||||||
agent = ProxyAgent(
|
agent = ProxyAgent(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
contextFactory=get_test_https_policy(),
|
contextFactory=get_test_https_policy(),
|
||||||
@ -389,7 +401,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
|
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
|
||||||
def test_http_request_via_proxy(self):
|
def test_http_request_via_proxy(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that requests can be made through a proxy.
|
Tests that requests can be made through a proxy.
|
||||||
"""
|
"""
|
||||||
@ -401,7 +413,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
|
{"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
|
||||||
)
|
)
|
||||||
def test_http_request_via_proxy_with_auth(self):
|
def test_http_request_via_proxy_with_auth(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that authenticated requests can be made through a proxy.
|
Tests that authenticated requests can be made through a proxy.
|
||||||
"""
|
"""
|
||||||
@ -412,7 +424,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
|
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
|
||||||
)
|
)
|
||||||
def test_http_request_via_https_proxy(self):
|
def test_http_request_via_https_proxy(self) -> None:
|
||||||
self._do_http_request_via_proxy(
|
self._do_http_request_via_proxy(
|
||||||
expect_proxy_ssl=True, expected_auth_credentials=None
|
expect_proxy_ssl=True, expected_auth_credentials=None
|
||||||
)
|
)
|
||||||
@ -424,13 +436,13 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
"no_proxy": "unused.com",
|
"no_proxy": "unused.com",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def test_http_request_via_https_proxy_with_auth(self):
|
def test_http_request_via_https_proxy_with_auth(self) -> None:
|
||||||
self._do_http_request_via_proxy(
|
self._do_http_request_via_proxy(
|
||||||
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
|
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
|
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
|
||||||
def test_https_request_via_proxy(self):
|
def test_https_request_via_proxy(self) -> None:
|
||||||
"""Tests that TLS-encrypted requests can be made through a proxy"""
|
"""Tests that TLS-encrypted requests can be made through a proxy"""
|
||||||
self._do_https_request_via_proxy(
|
self._do_https_request_via_proxy(
|
||||||
expect_proxy_ssl=False, expected_auth_credentials=None
|
expect_proxy_ssl=False, expected_auth_credentials=None
|
||||||
@ -440,7 +452,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
|
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
|
||||||
)
|
)
|
||||||
def test_https_request_via_proxy_with_auth(self):
|
def test_https_request_via_proxy_with_auth(self) -> None:
|
||||||
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
|
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
|
||||||
self._do_https_request_via_proxy(
|
self._do_https_request_via_proxy(
|
||||||
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
|
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
|
||||||
@ -449,7 +461,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
|
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
|
||||||
)
|
)
|
||||||
def test_https_request_via_https_proxy(self):
|
def test_https_request_via_https_proxy(self) -> None:
|
||||||
"""Tests that TLS-encrypted requests can be made through a proxy"""
|
"""Tests that TLS-encrypted requests can be made through a proxy"""
|
||||||
self._do_https_request_via_proxy(
|
self._do_https_request_via_proxy(
|
||||||
expect_proxy_ssl=True, expected_auth_credentials=None
|
expect_proxy_ssl=True, expected_auth_credentials=None
|
||||||
@ -459,7 +471,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
os.environ,
|
os.environ,
|
||||||
{"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
|
{"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
|
||||||
)
|
)
|
||||||
def test_https_request_via_https_proxy_with_auth(self):
|
def test_https_request_via_https_proxy_with_auth(self) -> None:
|
||||||
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
|
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
|
||||||
self._do_https_request_via_proxy(
|
self._do_https_request_via_proxy(
|
||||||
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
|
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
|
||||||
@ -469,7 +481,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self,
|
self,
|
||||||
expect_proxy_ssl: bool = False,
|
expect_proxy_ssl: bool = False,
|
||||||
expected_auth_credentials: Optional[bytes] = None,
|
expected_auth_credentials: Optional[bytes] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Send a http request via an agent and check that it is correctly received at
|
"""Send a http request via an agent and check that it is correctly received at
|
||||||
the proxy. The proxy can use either http or https.
|
the proxy. The proxy can use either http or https.
|
||||||
Args:
|
Args:
|
||||||
@ -501,6 +513,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
|
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
|
||||||
expected_sni=b"proxy.com" if expect_proxy_ssl else None,
|
expected_sni=b"proxy.com" if expect_proxy_ssl else None,
|
||||||
)
|
)
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
|
|
||||||
# the FakeTransport is async, so we need to pump the reactor
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
@ -542,7 +555,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self,
|
self,
|
||||||
expect_proxy_ssl: bool = False,
|
expect_proxy_ssl: bool = False,
|
||||||
expected_auth_credentials: Optional[bytes] = None,
|
expected_auth_credentials: Optional[bytes] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Send a https request via an agent and check that it is correctly received at
|
"""Send a https request via an agent and check that it is correctly received at
|
||||||
the proxy and client. The proxy can use either http or https.
|
the proxy and client. The proxy can use either http or https.
|
||||||
Args:
|
Args:
|
||||||
@ -606,10 +619,12 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
# now we make another test server to act as the upstream HTTP server.
|
# now we make another test server to act as the upstream HTTP server.
|
||||||
server_ssl_protocol = _wrap_server_factory_for_tls(
|
server_ssl_protocol = _wrap_server_factory_for_tls(
|
||||||
_get_test_protocol_factory()
|
_get_test_protocol_factory()
|
||||||
).buildProtocol(None)
|
).buildProtocol(dummy_address)
|
||||||
|
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
|
||||||
|
|
||||||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
|
||||||
proxy_server_transport = proxy_server.transport
|
proxy_server_transport = proxy_server.transport
|
||||||
|
assert proxy_server_transport is not None
|
||||||
server_ssl_protocol.makeConnection(proxy_server_transport)
|
server_ssl_protocol.makeConnection(proxy_server_transport)
|
||||||
|
|
||||||
# ... and replace the protocol on the proxy's transport with the
|
# ... and replace the protocol on the proxy's transport with the
|
||||||
@ -644,6 +659,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
|
|
||||||
# now there should be a pending request
|
# now there should be a pending request
|
||||||
http_server = server_ssl_protocol.wrappedProtocol
|
http_server = server_ssl_protocol.wrappedProtocol
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
@ -667,7 +683,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self.assertEqual(body, b"result")
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
|
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
|
||||||
def test_http_request_via_proxy_with_blacklist(self):
|
def test_http_request_via_proxy_with_blacklist(self) -> None:
|
||||||
# The blacklist includes the configured proxy IP.
|
# The blacklist includes the configured proxy IP.
|
||||||
agent = ProxyAgent(
|
agent = ProxyAgent(
|
||||||
BlacklistingReactorWrapper(
|
BlacklistingReactorWrapper(
|
||||||
@ -691,6 +707,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory, _get_test_protocol_factory()
|
client_factory, _get_test_protocol_factory()
|
||||||
)
|
)
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
|
|
||||||
# the FakeTransport is async, so we need to pump the reactor
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
@ -712,7 +729,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self.assertEqual(body, b"result")
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
|
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
|
||||||
def test_https_request_via_uppercase_proxy_with_blacklist(self):
|
def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
|
||||||
# The blacklist includes the configured proxy IP.
|
# The blacklist includes the configured proxy IP.
|
||||||
agent = ProxyAgent(
|
agent = ProxyAgent(
|
||||||
BlacklistingReactorWrapper(
|
BlacklistingReactorWrapper(
|
||||||
@ -737,11 +754,15 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
proxy_server = self._make_connection(
|
proxy_server = self._make_connection(
|
||||||
client_factory, _get_test_protocol_factory()
|
client_factory, _get_test_protocol_factory()
|
||||||
)
|
)
|
||||||
|
assert isinstance(proxy_server, HTTPChannel)
|
||||||
|
|
||||||
# fish the transports back out so that we can do the old switcheroo
|
# fish the transports back out so that we can do the old switcheroo
|
||||||
s2c_transport = proxy_server.transport
|
s2c_transport = proxy_server.transport
|
||||||
|
assert isinstance(s2c_transport, FakeTransport)
|
||||||
client_protocol = s2c_transport.other
|
client_protocol = s2c_transport.other
|
||||||
|
assert isinstance(client_protocol, _WrappingProtocol)
|
||||||
c2s_transport = client_protocol.transport
|
c2s_transport = client_protocol.transport
|
||||||
|
assert isinstance(c2s_transport, FakeTransport)
|
||||||
|
|
||||||
# the FakeTransport is async, so we need to pump the reactor
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
@ -762,8 +783,10 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
|
|
||||||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
||||||
ssl_protocol = ssl_factory.buildProtocol(None)
|
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
|
||||||
|
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
|
||||||
http_server = ssl_protocol.wrappedProtocol
|
http_server = ssl_protocol.wrappedProtocol
|
||||||
|
assert isinstance(http_server, HTTPChannel)
|
||||||
|
|
||||||
ssl_protocol.makeConnection(
|
ssl_protocol.makeConnection(
|
||||||
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
||||||
@ -797,28 +820,28 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
self.assertEqual(body, b"result")
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
|
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
|
||||||
def test_proxy_with_no_scheme(self):
|
def test_proxy_with_no_scheme(self) -> None:
|
||||||
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
|
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
|
||||||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
|
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
|
||||||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
|
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
|
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
|
||||||
def test_proxy_with_unsupported_scheme(self):
|
def test_proxy_with_unsupported_scheme(self) -> None:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
ProxyAgent(self.reactor, use_proxy=True)
|
ProxyAgent(self.reactor, use_proxy=True)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
|
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
|
||||||
def test_proxy_with_http_scheme(self):
|
def test_proxy_with_http_scheme(self) -> None:
|
||||||
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
|
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
|
||||||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
|
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
|
||||||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
|
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
|
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
|
||||||
def test_proxy_with_https_scheme(self):
|
def test_proxy_with_https_scheme(self) -> None:
|
||||||
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||||
self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
|
assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
|
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
|
||||||
)
|
)
|
||||||
@ -828,7 +851,7 @@ class MatrixFederationAgentTests(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def _wrap_server_factory_for_tls(
|
def _wrap_server_factory_for_tls(
|
||||||
factory: IProtocolFactory, sanlist: Iterable[bytes] = None
|
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
|
||||||
) -> IProtocolFactory:
|
) -> IProtocolFactory:
|
||||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||||
|
|
||||||
@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory:
|
|||||||
return server_factory
|
return server_factory
|
||||||
|
|
||||||
|
|
||||||
def _log_request(request: str):
|
def _log_request(request: str) -> None:
|
||||||
"""Implements Factory.log, which is expected by Request.finish"""
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
logger.info(f"Completed request {request}")
|
logger.info(f"Completed request {request}")
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
import json
|
import json
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Tuple
|
from typing import Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
@ -33,7 +33,7 @@ from tests import unittest
|
|||||||
from tests.http.server._base import test_disconnect
|
from tests.http.server._base import test_disconnect
|
||||||
|
|
||||||
|
|
||||||
def make_request(content):
|
def make_request(content: Union[bytes, JsonDict]) -> Mock:
|
||||||
"""Make an object that acts enough like a request."""
|
"""Make an object that acts enough like a request."""
|
||||||
request = Mock(spec=["method", "uri", "content"])
|
request = Mock(spec=["method", "uri", "content"])
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ def make_request(content):
|
|||||||
|
|
||||||
|
|
||||||
class TestServletUtils(unittest.TestCase):
|
class TestServletUtils(unittest.TestCase):
|
||||||
def test_parse_json_value(self):
|
def test_parse_json_value(self) -> None:
|
||||||
"""Basic tests for parse_json_value_from_request."""
|
"""Basic tests for parse_json_value_from_request."""
|
||||||
# Test round-tripping.
|
# Test round-tripping.
|
||||||
obj = {"foo": 1}
|
obj = {"foo": 1}
|
||||||
@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase):
|
|||||||
with self.assertRaises(SynapseError):
|
with self.assertRaises(SynapseError):
|
||||||
parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
|
parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
|
||||||
|
|
||||||
def test_parse_json_object(self):
|
def test_parse_json_object(self) -> None:
|
||||||
"""Basic tests for parse_json_object_from_request."""
|
"""Basic tests for parse_json_object_from_request."""
|
||||||
# Test empty.
|
# Test empty.
|
||||||
result = parse_json_object_from_request(
|
result = parse_json_object_from_request(
|
||||||
|
@ -17,22 +17,24 @@ from netaddr import IPSet
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.http import RequestTimedOutError
|
from synapse.http import RequestTimedOutError
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class SimpleHttpClientTests(HomeserverTestCase):
|
class SimpleHttpClientTests(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs: "HomeServer"):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None:
|
||||||
# Add a DNS entry for a test server
|
# Add a DNS entry for a test server
|
||||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||||
|
|
||||||
self.cl = hs.get_simple_http_client()
|
self.cl = hs.get_simple_http_client()
|
||||||
|
|
||||||
def test_dns_error(self):
|
def test_dns_error(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the DNS lookup returns an error, it will bubble up.
|
If the DNS lookup returns an error, it will bubble up.
|
||||||
"""
|
"""
|
||||||
@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
|
|||||||
f = self.failureResultOf(d)
|
f = self.failureResultOf(d)
|
||||||
self.assertIsInstance(f.value, DNSLookupError)
|
self.assertIsInstance(f.value, DNSLookupError)
|
||||||
|
|
||||||
def test_client_connection_refused(self):
|
def test_client_connection_refused(self) -> None:
|
||||||
d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
|
d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
|
||||||
|
|
||||||
self.pump()
|
self.pump()
|
||||||
@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertIs(f.value, e)
|
self.assertIs(f.value, e)
|
||||||
|
|
||||||
def test_client_never_connect(self):
|
def test_client_never_connect(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the HTTP request is not connected and is timed out, it'll give a
|
If the HTTP request is not connected and is timed out, it'll give a
|
||||||
ConnectingCancelledError or TimeoutError.
|
ConnectingCancelledError or TimeoutError.
|
||||||
@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(f.value, RequestTimedOutError)
|
self.assertIsInstance(f.value, RequestTimedOutError)
|
||||||
|
|
||||||
def test_client_connect_no_response(self):
|
def test_client_connect_no_response(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the HTTP request is connected, but gets no response before being
|
If the HTTP request is connected, but gets no response before being
|
||||||
timed out, it'll give a ResponseNeverReceived.
|
timed out, it'll give a ResponseNeverReceived.
|
||||||
@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(f.value, RequestTimedOutError)
|
self.assertIsInstance(f.value, RequestTimedOutError)
|
||||||
|
|
||||||
def test_client_ip_range_blacklist(self):
|
def test_client_ip_range_blacklist(self) -> None:
|
||||||
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
|
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
|
||||||
|
|
||||||
# Add some DNS entries we'll blacklist
|
# Add some DNS entries we'll blacklist
|
||||||
|
@ -13,18 +13,20 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet.address import IPv6Address
|
from twisted.internet.address import IPv6Address
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import MemoryReactor, StringTransport
|
||||||
|
|
||||||
from synapse.app.homeserver import SynapseHomeServer
|
from synapse.app.homeserver import SynapseHomeServer
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
class SynapseRequestTestCase(HomeserverTestCase):
|
class SynapseRequestTestCase(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
|
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
|
||||||
|
|
||||||
def test_large_request(self):
|
def test_large_request(self) -> None:
|
||||||
"""overlarge HTTP requests should be rejected"""
|
"""overlarge HTTP requests should be rejected"""
|
||||||
self.hs.start_listening()
|
self.hs.start_listening()
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ from synapse.logging.context import ContextResourceUsage
|
|||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.engines import PostgresEngine, create_engine
|
from synapse.storage.engines import PostgresEngine, create_engine
|
||||||
from synapse.types import JsonDict
|
from synapse.types import ISynapseReactor, JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.utils import (
|
from tests.utils import (
|
||||||
@ -401,7 +401,9 @@ def make_request(
|
|||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
|
||||||
@implementer(IReactorPluggableNameResolver)
|
# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
|
||||||
|
# marking this as an implementer of the latter seems to keep mypy-zope happier.
|
||||||
|
@implementer(IReactorPluggableNameResolver, ISynapseReactor)
|
||||||
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||||
"""
|
"""
|
||||||
A MemoryReactorClock that supports callFromThread.
|
A MemoryReactorClock that supports callFromThread.
|
||||||
|
Loading…
Reference in New Issue
Block a user