Clean up the test code for client disconnections (#12929)

* Reword failure message about `await_result=False`
* Use `reactor.advance()` instead of `reactor.pump()`
* Raise `AssertionError`s ourselves
* Un-instance method `_test_disconnect`
* Replace `ThreadedMemoryReactorClock` with `MemoryReactorClock`
This commit is contained in:
Sean Quah 2022-06-07 18:17:32 +01:00 committed by GitHub
parent 586bfc6dc0
commit 3c1c40d843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 87 deletions

1
changelog.d/12929.misc Normal file
View File

@ -0,0 +1 @@
Clean up the test code for client disconnection.

View File

@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.http.server._base import test_disconnect
class CancellableFederationServlet(BaseFederationServlet): class CancellableFederationServlet(BaseFederationServlet):
@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
return HTTPStatus.OK, {"result": True} return HTTPStatus.OK, {"result": True}
class BaseFederationServletCancellationTests( class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `BaseFederationServlet` cancellation.""" """Tests for `BaseFederationServlet` cancellation."""
skip = "`BaseFederationServlet` does not support cancellation yet." skip = "`BaseFederationServlet` does not support cancellation yet."
@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed. # request won't be processed.
self.pump() self.pump()
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=True, expect_cancellation=True,
@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed. # request won't be processed.
self.pump() self.pump()
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=False, expect_cancellation=False,

View File

@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict from synapse.types import JsonDict
from tests import unittest from tests.server import FakeChannel, make_request
from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
from tests.unittest import logcontext_clean from tests.unittest import logcontext_clean
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
class EndpointCancellationTestHelperMixin(unittest.TestCase): def test_disconnect(
"""Provides helper methods for testing cancellation of endpoints.""" reactor: MemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.
def _test_disconnect( Args:
self, reactor: The twisted reactor running the request handler.
reactor: ThreadedMemoryReactorClock, channel: The `FakeChannel` for the request.
channel: FakeChannel, expect_cancellation: `True` if request processing is expected to be cancelled,
expect_cancellation: bool, `False` if the request should run to completion.
expected_body: Union[bytes, JsonDict], expected_body: The expected response for the request.
expected_code: Optional[int] = None, expected_code: The expected status code for the request. Defaults to `200` or
) -> None: `499` depending on `expect_cancellation`.
"""Disconnects an in-flight request and checks the response. """
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK
Args: request = channel.request
reactor: The twisted reactor running the request handler. if channel.is_finished():
channel: The `FakeChannel` for the request. raise AssertionError(
expect_cancellation: `True` if request processing is expected to be
cancelled, `False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200`
or `499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK
request = channel.request
self.assertFalse(
channel.is_finished(),
"Request finished before we could disconnect - " "Request finished before we could disconnect - "
"was `await_result=False` passed to `make_request`?", "ensure `await_result=False` is passed to `make_request`.",
) )
# We're about to disconnect the request. This also disconnects the channel, so # We're about to disconnect the request. This also disconnects the channel, so we
# we have to rely on mocks to extract the response. # have to rely on mocks to extract the response.
respond_method: Callable[..., Any] respond_method: Callable[..., Any]
if isinstance(expected_body, bytes): if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes respond_method = respond_with_html_bytes
else:
respond_method = respond_with_json
with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())
if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
else: else:
respond_method = respond_with_json respond_mock.assert_not_called()
with mock.patch( # The handler is expected to run to completion.
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method reactor.advance(1.0)
) as respond_mock: respond_mock.assert_called_once()
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())
if expect_cancellation: args, _kwargs = respond_mock.call_args
# An immediate cancellation is expected. code, body = args[1], args[2]
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
else:
respond_mock.assert_not_called()
# The handler is expected to run to completion. if code != expected_code:
reactor.pump([1.0]) raise AssertionError(
respond_mock.assert_called_once() f"{code} != {expected_code} : "
args, _kwargs = respond_mock.call_args "Request did not finish with the expected status code."
code, body = args[1], args[2] )
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code) if request.code != expected_code:
self.assertEqual(body, expected_body) raise AssertionError(
f"{request.code} != {expected_code} : "
"Request did not finish with the expected status code."
)
if body != expected_body:
raise AssertionError(
f"{body!r} != {expected_body!r} : "
"Request did not finish with the expected status code."
)
@logcontext_clean @logcontext_clean

View File

@ -30,7 +30,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.http.server._base import test_disconnect
def make_request(content): def make_request(content):
@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True} return HTTPStatus.OK, {"result": True}
class TestRestServletCancellation( class TestRestServletCancellation(unittest.HomeserverTestCase):
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `RestServlet` cancellation.""" """Tests for `RestServlet` cancellation."""
servlets = [ servlets = [
@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None: def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled.""" """Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False) channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=True, expect_cancellation=True,
@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None: def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled.""" """Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False) channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=False, expect_cancellation=False,

View File

@ -25,7 +25,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint): class CancellableReplicationEndpoint(ReplicationEndpoint):
@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True} return HTTPStatus.OK, {"result": True}
class ReplicationEndpointCancellationTestCase( class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `ReplicationEndpoint` cancellation.""" """Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self): def create_test_resource(self):
@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled.""" """Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False) channel = self.make_request("POST", path, await_result=False)
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=True, expect_cancellation=True,
@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled.""" """Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False) channel = self.make_request("POST", path, await_result=False)
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=False, expect_cancellation=False,

View File

@ -34,7 +34,7 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.http.server._base import test_disconnect
from tests.server import ( from tests.server import (
FakeSite, FakeSite,
ThreadedMemoryReactorClock, ThreadedMemoryReactorClock,
@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
return HTTPStatus.OK, b"ok" return HTTPStatus.OK, b"ok"
class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation.""" """Tests for `DirectServeJsonResource` cancellation."""
def setUp(self): def setUp(self):
@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request( channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False self.reactor, self.site, "GET", "/sleep", await_result=False
) )
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=True, expect_cancellation=True,
@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request( channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False self.reactor, self.site, "POST", "/sleep", await_result=False
) )
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=False, expect_cancellation=False,
@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
) )
class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation.""" """Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self): def setUp(self):
@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request( channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False self.reactor, self.site, "GET", "/sleep", await_result=False
) )
self._test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
expect_cancellation=True, expect_cancellation=True,
@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request( channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False self.reactor, self.site, "POST", "/sleep", await_result=False
) )
self._test_disconnect( test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok" self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
) )