mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add cancellation support to @cached
and @cachedList
decorators (#12183)
These decorators mostly support cancellation already. Add cancellation tests and fix use of finished logging contexts by delaying cancellation, as suggested by @erikjohnston. Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
parent
605d161d7d
commit
2fcf4b3f6c
1
changelog.d/12183.misc
Normal file
1
changelog.d/12183.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add cancellation support to `@cached` and `@cachedList` decorators.
|
@ -41,6 +41,7 @@ from twisted.python.failure import Failure
|
|||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.async_helpers import delay_cancellation
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.deferred_cache import DeferredCache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|||||||
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
||||||
ret = cache.set(cache_key, ret, callback=invalidate_callback)
|
ret = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||||
|
|
||||||
|
# We started a new call to `self.orig`, so we must always wait for it to
|
||||||
|
# complete. Otherwise we might mark our current logging context as
|
||||||
|
# finished while `self.orig` is still using it in the background.
|
||||||
|
ret = delay_cancellation(ret)
|
||||||
|
|
||||||
return make_deferred_yieldable(ret)
|
return make_deferred_yieldable(ret)
|
||||||
|
|
||||||
wrapped = cast(_CachedFunction, _wrapped)
|
wrapped = cast(_CachedFunction, _wrapped)
|
||||||
@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||||||
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
||||||
lambda _: results, unwrapFirstError
|
lambda _: results, unwrapFirstError
|
||||||
)
|
)
|
||||||
|
if missing:
|
||||||
|
# We started a new call to `self.orig`, so we must always wait for it to
|
||||||
|
# complete. Otherwise we might mark our current logging context as
|
||||||
|
# finished while `self.orig` is still using it in the background.
|
||||||
|
d = delay_cancellation(d)
|
||||||
return make_deferred_yieldable(d)
|
return make_deferred_yieldable(d)
|
||||||
else:
|
else:
|
||||||
return defer.succeed(results)
|
return defer.succeed(results)
|
||||||
|
@ -17,7 +17,7 @@ from typing import Set
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import CancelledError, Deferred
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
@ -28,7 +28,7 @@ from synapse.logging.context import (
|
|||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
)
|
)
|
||||||
from synapse.util.caches import descriptors
|
from synapse.util.caches import descriptors
|
||||||
from synapse.util.caches.descriptors import cached, lru_cache
|
from synapse.util.caches.descriptors import cached, cachedList, lru_cache
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import get_awaitable_result
|
from tests.test_utils import get_awaitable_result
|
||||||
@ -493,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
|
|||||||
obj.invalidate()
|
obj.invalidate()
|
||||||
top_invalidate.assert_called_once()
|
top_invalidate.assert_called_once()
|
||||||
|
|
||||||
|
def test_cancel(self):
|
||||||
|
"""Test that cancelling a lookup does not cancel other lookups"""
|
||||||
|
complete_lookup: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
@cached()
|
||||||
|
async def fn(self, arg1):
|
||||||
|
await complete_lookup
|
||||||
|
return str(arg1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
d1 = obj.fn(123)
|
||||||
|
d2 = obj.fn(123)
|
||||||
|
self.assertFalse(d1.called)
|
||||||
|
self.assertFalse(d2.called)
|
||||||
|
|
||||||
|
# Cancel `d1`, which is the lookup that caused `fn` to run.
|
||||||
|
d1.cancel()
|
||||||
|
|
||||||
|
# `d2` should complete normally.
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
self.failureResultOf(d1, CancelledError)
|
||||||
|
self.assertEqual(d2.result, "123")
|
||||||
|
|
||||||
|
def test_cancel_logcontexts(self):
|
||||||
|
"""Test that cancellation does not break logcontexts.
|
||||||
|
|
||||||
|
* The `CancelledError` must be raised with the correct logcontext.
|
||||||
|
* The inner lookup must not resume with a finished logcontext.
|
||||||
|
* The inner lookup must not restore a finished logcontext when done.
|
||||||
|
"""
|
||||||
|
complete_lookup: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
inner_context_was_finished = False
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def fn(self, arg1):
|
||||||
|
await make_deferred_yieldable(complete_lookup)
|
||||||
|
self.inner_context_was_finished = current_context().finished
|
||||||
|
return str(arg1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
async def do_lookup():
|
||||||
|
with LoggingContext("c1") as c1:
|
||||||
|
try:
|
||||||
|
await obj.fn(123)
|
||||||
|
self.fail("No CancelledError thrown")
|
||||||
|
except CancelledError:
|
||||||
|
self.assertEqual(
|
||||||
|
current_context(),
|
||||||
|
c1,
|
||||||
|
"CancelledError was not raised with the correct logcontext",
|
||||||
|
)
|
||||||
|
# suppress the error and succeed
|
||||||
|
|
||||||
|
d = defer.ensureDeferred(do_lookup())
|
||||||
|
d.cancel()
|
||||||
|
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
self.successResultOf(d)
|
||||||
|
self.assertFalse(
|
||||||
|
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
|
||||||
|
)
|
||||||
|
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||||
|
|
||||||
|
|
||||||
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
||||||
"""More tests for @cached
|
"""More tests for @cached
|
||||||
@ -865,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
|||||||
obj.fn.invalidate((10, 2))
|
obj.fn.invalidate((10, 2))
|
||||||
invalidate0.assert_called_once()
|
invalidate0.assert_called_once()
|
||||||
invalidate1.assert_called_once()
|
invalidate1.assert_called_once()
|
||||||
|
|
||||||
|
def test_cancel(self):
|
||||||
|
"""Test that cancelling a lookup does not cancel other lookups"""
|
||||||
|
complete_lookup: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
@cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@cachedList(cached_method_name="fn", list_name="args")
|
||||||
|
async def list_fn(self, args):
|
||||||
|
await complete_lookup
|
||||||
|
return {arg: str(arg) for arg in args}
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
d1 = obj.list_fn([123, 456])
|
||||||
|
d2 = obj.list_fn([123, 456, 789])
|
||||||
|
self.assertFalse(d1.called)
|
||||||
|
self.assertFalse(d2.called)
|
||||||
|
|
||||||
|
d1.cancel()
|
||||||
|
|
||||||
|
# `d2` should complete normally.
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
self.failureResultOf(d1, CancelledError)
|
||||||
|
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
|
||||||
|
|
||||||
|
def test_cancel_logcontexts(self):
|
||||||
|
"""Test that cancellation does not break logcontexts.
|
||||||
|
|
||||||
|
* The `CancelledError` must be raised with the correct logcontext.
|
||||||
|
* The inner lookup must not resume with a finished logcontext.
|
||||||
|
* The inner lookup must not restore a finished logcontext when done.
|
||||||
|
"""
|
||||||
|
complete_lookup: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
inner_context_was_finished = False
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@cachedList(cached_method_name="fn", list_name="args")
|
||||||
|
async def list_fn(self, args):
|
||||||
|
await make_deferred_yieldable(complete_lookup)
|
||||||
|
self.inner_context_was_finished = current_context().finished
|
||||||
|
return {arg: str(arg) for arg in args}
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
async def do_lookup():
|
||||||
|
with LoggingContext("c1") as c1:
|
||||||
|
try:
|
||||||
|
await obj.list_fn([123])
|
||||||
|
self.fail("No CancelledError thrown")
|
||||||
|
except CancelledError:
|
||||||
|
self.assertEqual(
|
||||||
|
current_context(),
|
||||||
|
c1,
|
||||||
|
"CancelledError was not raised with the correct logcontext",
|
||||||
|
)
|
||||||
|
# suppress the error and succeed
|
||||||
|
|
||||||
|
d = defer.ensureDeferred(do_lookup())
|
||||||
|
d.cancel()
|
||||||
|
|
||||||
|
complete_lookup.callback(None)
|
||||||
|
self.successResultOf(d)
|
||||||
|
self.assertFalse(
|
||||||
|
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
|
||||||
|
)
|
||||||
|
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||||
|
Loading…
Reference in New Issue
Block a user