Add stop_cancellation utility function (#12106)

This commit is contained in:
Sean Quah 2022-03-01 13:51:03 +00:00 committed by GitHub
parent c893632319
commit 91bc15c772
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 0 deletions

1
changelog.d/12106.misc Normal file
View File

@ -0,0 +1 @@
Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled.

View File

@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
return value return value
return DoneAwaitable(value) return DoneAwaitable(value)
def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`.
Args:
deferred: The `Deferred` to protect against cancellation. Must not follow the
Synapse logcontext rules.
Returns:
A new `Deferred`, which will contain the result of the original `Deferred`,
but will not propagate cancellation through to the original. When cancelled,
the new `Deferred` will fail with a `CancelledError` and will not follow the
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
the new `Deferred`.
"""
new_deferred: defer.Deferred[T] = defer.Deferred()
deferred.chainDeferred(new_deferred)
return new_deferred

View File

@ -27,6 +27,7 @@ from synapse.logging.context import (
from synapse.util.async_helpers import ( from synapse.util.async_helpers import (
ObservableDeferred, ObservableDeferred,
concurrently_execute, concurrently_execute,
stop_cancellation,
timeout_deferred, timeout_deferred,
) )
@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase):
d2 = ensureDeferred(caller()) d2 = ensureDeferred(caller())
d1.callback(0) d1.callback(0)
self.successResultOf(d2) self.successResultOf(d2)
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
def test_succeed(self):
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
# Success should propagate through.
deferred.callback("success")
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))
def test_failure(self):
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
# Failure should propagate through.
deferred.errback(ValueError("abc"))
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, ValueError)
self.assertIsNone(deferred.result, "`Failure` was not consumed")
def test_cancellation(self):
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
# Cancel the new `Deferred`.
wrapper_deferred.cancel()
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, CancelledError)
self.assertFalse(
deferred.called, "Original `Deferred` was unexpectedly cancelled."
)
# Now make the inner `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")