Respect the @cancellable flag for RestServlets and BaseFederationServlets (#12699)

Both `RestServlet`s and `BaseFederationServlet`s register their handlers
with `HttpServer.register_paths` / `JsonResource.register_paths`. Update
`JsonResource` to respect the `@cancellable` flag on handlers registered
in this way.

Although `ReplicationEndpoint` also registers itself using
`register_paths`, it does not pass the handler method that would have the
`@cancellable` flag directly, and so needs separate handling.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-05-11 12:25:13 +01:00 committed by GitHub
parent dffecade7d
commit 9d8e380d2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 191 additions and 2 deletions

View file

@ -12,16 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from http import HTTPStatus
from io import BytesIO
from typing import Tuple
from unittest.mock import Mock
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import cancellable
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_json_value_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
def make_request(content):
@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
# Test not an object
with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]'))
class CancellableRestServlet(RestServlet):
"""A `RestServlet` with a mix of cancellable and uncancellable handlers."""
PATTERNS = client_patterns("/sleep$")
def __init__(self, hs: HomeServer):
super().__init__()
self.clock = hs.get_clock()
@cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
class TestRestServletCancellation(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `RestServlet` cancellation."""
servlets = [
lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
]
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
)
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
expected_body={"result": True},
)