Update delay_cancellation to accept any awaitable (#12468)

This will mainly be useful when dealing with module callbacks, which are
all typed as returning `Awaitable`s instead of coroutines or
`Deferred`s.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-04-22 18:20:06 +01:00 committed by GitHub
parent b82fff66df
commit a50fb411b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 14 deletions

View file

@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
def test_cancellation(self):
def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_coroutine_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
async def task():
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
# tracebacks will be printed in logs.
raise ValueError("abc")
wrapper_deferred = delay_cancellation(task())
# Cancel the new `Deferred`.
wrapper_deferred.cancel()
self.assertNoResult(wrapper_deferred)
self.assertFalse(
blocking_deferred.called, "Cancellation was propagated too deep"
)
self.assertFalse(completion_deferred.called)
# Unblock the task.
blocking_deferred.callback(None)
self.assertTrue(completion_deferred.called)
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed.
@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
async def outer():
with LoggingContext("c") as c:
try:
await delay_cancellation(defer.ensureDeferred(inner()))
await delay_cancellation(inner())
self.fail("`CancelledError` was not raised")
except CancelledError:
self.assertEqual(c, current_context())