Update delay_cancellation to accept any awaitable (#12468)

This will mainly be useful when dealing with module callbacks, which are
all typed as returning `Awaitable`s instead of coroutines or
`Deferred`s.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-04-22 18:20:06 +01:00 committed by GitHub
parent b82fff66df
commit a50fb411b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 14 deletions

1
changelog.d/12468.misc Normal file
View File

@ -0,0 +1 @@
Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.

View File

@ -41,7 +41,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal from typing_extensions import Literal
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
@ -794,7 +793,7 @@ class DatabasePool:
# We also wait until everything above is done before releasing the # We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been # `CancelledError`, so that logging contexts won't get used after they have been
# finished. # finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction())) return await delay_cancellation(_runInteraction())
async def runWithConnection( async def runWithConnection(
self, self,

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import abc import abc
import asyncio
import collections import collections
import inspect import inspect
import itertools import itertools
@ -25,6 +26,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Collection, Collection,
Coroutine,
Dict, Dict,
Generic, Generic,
Hashable, Hashable,
@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
return new_deferred return new_deferred
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": @overload
"""Delay cancellation of a `Deferred` until it resolves. def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
...
@overload
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
...
@overload
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
...
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
"""Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
resolve with a `CancelledError` until the original `Deferred` resolves. resolve with a `CancelledError` until the original awaitable resolves.
Args: Args:
deferred: The `Deferred` to protect against cancellation. May optionally follow deferred: The coroutine or `Deferred` to protect against cancellation. May
the Synapse logcontext rules. optionally follow the Synapse logcontext rules.
Returns: Returns:
A new `Deferred`, which will contain the result of the original `Deferred`. A new `Deferred`, which will contain the result of the original coroutine or
The new `Deferred` will not propagate cancellation through to the original. `Deferred`. The new `Deferred` will not propagate cancellation through to the
When cancelled, the new `Deferred` will wait until the original `Deferred` original coroutine or `Deferred`.
resolves before failing with a `CancelledError`.
The new `Deferred` will follow the Synapse logcontext rules if `deferred` When cancelled, the new `Deferred` will wait until the original coroutine or
`Deferred` resolves before failing with a `CancelledError`.
The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
wrapped with `make_deferred_yieldable`. wrapped with `make_deferred_yieldable`.
""" """
# First, convert the awaitable into a `Deferred`.
if isinstance(awaitable, defer.Deferred):
deferred = awaitable
elif asyncio.iscoroutine(awaitable):
# Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
# type-checking, but we'd need Twisted >= 21.2.
deferred = defer.ensureDeferred(awaitable)
else:
# We have no idea what to do with this awaitable.
# We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
# `make_awaitable`, and let the caller `await` it normally.
return awaitable
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
# before the new deferred is cancelled, we `pause` it to stop the cancellation # before the new deferred is cancelled, we `pause` it to stop the cancellation
# propagating. we then `unpause` it once the wrapped deferred completes, to # propagating. we then `unpause` it once the wrapped deferred completes, to

View File

@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase): class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function.""" """Tests for the `delay_cancellation` function."""
def test_cancellation(self): def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original.""" """Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred) wrapper_deferred = delay_cancellation(deferred)
@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`. # Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError) self.failureResultOf(wrapper_deferred, CancelledError)
def test_coroutine_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
async def task():
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
# tracebacks will be printed in logs.
raise ValueError("abc")
wrapper_deferred = delay_cancellation(task())
# Cancel the new `Deferred`.
wrapper_deferred.cancel()
self.assertNoResult(wrapper_deferred)
self.assertFalse(
blocking_deferred.called, "Cancellation was propagated too deep"
)
self.assertFalse(completion_deferred.called)
# Unblock the task.
blocking_deferred.callback(None)
self.assertTrue(completion_deferred.called)
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_suppresses_second_cancellation(self): def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed. """Test that a second cancellation is suppressed.
@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
async def outer(): async def outer():
with LoggingContext("c") as c: with LoggingContext("c") as c:
try: try:
await delay_cancellation(defer.ensureDeferred(inner())) await delay_cancellation(inner())
self.fail("`CancelledError` was not raised") self.fail("`CancelledError` was not raised")
except CancelledError: except CancelledError:
self.assertEqual(c, current_context()) self.assertEqual(c, current_context())