mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Fix some error cases in the caching layer. (#5749)
There was some inconsistent behaviour in the caching layer around how exceptions were handled - particularly synchronously-thrown ones. This seems to be most easily handled by pushing the creation of ObservableDeferreds down from CacheDescriptor to the Cache.
This commit is contained in:
parent
f16aa3a44b
commit
618bd1ee76
1
changelog.d/5749.misc
Normal file
1
changelog.d/5749.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix some error cases in the caching layer.
|
@ -19,8 +19,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import six
|
from six import itervalues
|
||||||
from six import itervalues, string_types
|
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
|
|||||||
from synapse.util.caches import get_cache_factor_for
|
from synapse.util.caches import get_cache_factor_for
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
from synapse.util.stringutils import to_ascii
|
|
||||||
|
|
||||||
from . import register_cache
|
from . import register_cache
|
||||||
|
|
||||||
@ -124,7 +122,7 @@ class Cache(object):
|
|||||||
update_metrics (bool): whether to update the cache hit rate metrics
|
update_metrics (bool): whether to update the cache hit rate metrics
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Either a Deferred or the raw result
|
Either an ObservableDeferred or the raw result
|
||||||
"""
|
"""
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
||||||
@ -148,9 +146,14 @@ class Cache(object):
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
def set(self, key, value, callback=None):
|
def set(self, key, value, callback=None):
|
||||||
|
if not isinstance(value, defer.Deferred):
|
||||||
|
raise TypeError("not a Deferred")
|
||||||
|
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
entry = CacheEntry(deferred=value, callbacks=callbacks)
|
observable = ObservableDeferred(value, consumeErrors=True)
|
||||||
|
observer = defer.maybeDeferred(observable.observe)
|
||||||
|
entry = CacheEntry(deferred=observable, callbacks=callbacks)
|
||||||
|
|
||||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||||
if existing_entry:
|
if existing_entry:
|
||||||
@ -158,20 +161,31 @@ class Cache(object):
|
|||||||
|
|
||||||
self._pending_deferred_cache[key] = entry
|
self._pending_deferred_cache[key] = entry
|
||||||
|
|
||||||
def shuffle(result):
|
def compare_and_pop():
|
||||||
|
"""Check if our entry is still the one in _pending_deferred_cache, and
|
||||||
|
if so, pop it.
|
||||||
|
|
||||||
|
Returns true if the entries matched.
|
||||||
|
"""
|
||||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||||
if existing_entry is entry:
|
if existing_entry is entry:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# oops, the _pending_deferred_cache has been updated since
|
||||||
|
# we started our query, so we are out of date.
|
||||||
|
#
|
||||||
|
# Better put back whatever we took out. (We do it this way
|
||||||
|
# round, rather than peeking into the _pending_deferred_cache
|
||||||
|
# and then removing on a match, to make the common case faster)
|
||||||
|
if existing_entry is not None:
|
||||||
|
self._pending_deferred_cache[key] = existing_entry
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cb(result):
|
||||||
|
if compare_and_pop():
|
||||||
self.cache.set(key, result, entry.callbacks)
|
self.cache.set(key, result, entry.callbacks)
|
||||||
else:
|
else:
|
||||||
# oops, the _pending_deferred_cache has been updated since
|
|
||||||
# we started our query, so we are out of date.
|
|
||||||
#
|
|
||||||
# Better put back whatever we took out. (We do it this way
|
|
||||||
# round, rather than peeking into the _pending_deferred_cache
|
|
||||||
# and then removing on a match, to make the common case faster)
|
|
||||||
if existing_entry is not None:
|
|
||||||
self._pending_deferred_cache[key] = existing_entry
|
|
||||||
|
|
||||||
# we're not going to put this entry into the cache, so need
|
# we're not going to put this entry into the cache, so need
|
||||||
# to make sure that the invalidation callbacks are called.
|
# to make sure that the invalidation callbacks are called.
|
||||||
# That was probably done when _pending_deferred_cache was
|
# That was probably done when _pending_deferred_cache was
|
||||||
@ -179,9 +193,16 @@ class Cache(object):
|
|||||||
# `invalidate` being previously called, in which case it may
|
# `invalidate` being previously called, in which case it may
|
||||||
# not have been. Either way, let's double-check now.
|
# not have been. Either way, let's double-check now.
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
return result
|
|
||||||
|
|
||||||
entry.deferred.addCallback(shuffle)
|
def eb(_fail):
|
||||||
|
compare_and_pop()
|
||||||
|
entry.invalidate()
|
||||||
|
|
||||||
|
# once the deferred completes, we can move the entry from the
|
||||||
|
# _pending_deferred_cache to the real cache.
|
||||||
|
#
|
||||||
|
observer.addCallbacks(cb, eb)
|
||||||
|
return observable
|
||||||
|
|
||||||
def prefill(self, key, value, callback=None):
|
def prefill(self, key, value, callback=None):
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||||||
|
|
||||||
ret.addErrback(onErr)
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
# If our cache_key is a string on py2, try to convert to ascii
|
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||||
# to save a bit of space in large caches. Py3 does this
|
|
||||||
# internally automatically.
|
|
||||||
if six.PY2 and isinstance(cache_key, string_types):
|
|
||||||
cache_key = to_ascii(cache_key)
|
|
||||||
|
|
||||||
result_d = ObservableDeferred(ret, consumeErrors=True)
|
|
||||||
cache.set(cache_key, result_d, callback=invalidate_callback)
|
|
||||||
observer = result_d.observe()
|
observer = result_d.observe()
|
||||||
|
|
||||||
if isinstance(observer, defer.Deferred):
|
return make_deferred_yieldable(observer)
|
||||||
return make_deferred_yieldable(observer)
|
|
||||||
else:
|
|
||||||
return observer
|
|
||||||
|
|
||||||
if self.num_args == 1:
|
if self.num_args == 1:
|
||||||
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
||||||
@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
missing.add(arg)
|
missing.add(arg)
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
# we need an observable deferred for each entry in the list,
|
# we need a deferred for each entry in the list,
|
||||||
# which we put in the cache. Each deferred resolves with the
|
# which we put in the cache. Each deferred resolves with the
|
||||||
# relevant result for that key.
|
# relevant result for that key.
|
||||||
deferreds_map = {}
|
deferreds_map = {}
|
||||||
@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
deferred = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
deferreds_map[arg] = deferred
|
deferreds_map[arg] = deferred
|
||||||
key = arg_to_cache_key(arg)
|
key = arg_to_cache_key(arg)
|
||||||
observable = ObservableDeferred(deferred)
|
cache.set(key, deferred, callback=invalidate_callback)
|
||||||
cache.set(key, observable, callback=invalidate_callback)
|
|
||||||
|
|
||||||
def complete_all(res):
|
def complete_all(res):
|
||||||
# the wrapped function has completed. It returns a
|
# the wrapped function has completed. It returns a
|
||||||
|
@ -27,6 +27,7 @@ from synapse.logging.context import (
|
|||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
)
|
)
|
||||||
from synapse.util.caches import descriptors
|
from synapse.util.caches import descriptors
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
|
|||||||
d2 = defer.Deferred()
|
d2 = defer.Deferred()
|
||||||
cache.set("key2", d2, partial(record_callback, 1))
|
cache.set("key2", d2, partial(record_callback, 1))
|
||||||
|
|
||||||
# lookup should return the deferreds
|
# lookup should return observable deferreds
|
||||||
self.assertIs(cache.get("key1"), d1)
|
self.assertFalse(cache.get("key1").has_called())
|
||||||
self.assertIs(cache.get("key2"), d2)
|
self.assertFalse(cache.get("key2").has_called())
|
||||||
|
|
||||||
# let one of the lookups complete
|
# let one of the lookups complete
|
||||||
d2.callback("result2")
|
d2.callback("result2")
|
||||||
|
|
||||||
|
# for now at least, the cache will return real results rather than an
|
||||||
|
# observabledeferred
|
||||||
self.assertEqual(cache.get("key2"), "result2")
|
self.assertEqual(cache.get("key2"), "result2")
|
||||||
|
|
||||||
# now do the invalidation
|
# now do the invalidation
|
||||||
@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(r, "chips")
|
self.assertEqual(r, "chips")
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_with_sync_exception(self):
|
||||||
|
"""If the wrapped function throws synchronously, things should continue to work
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
raise SynapseError(100, "mai spoon iz too big!!1")
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# this should fail immediately
|
||||||
|
d = obj.fn(1)
|
||||||
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
# ... leaving the cache empty
|
||||||
|
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||||
|
|
||||||
|
# and a second call should result in a second exception
|
||||||
|
d = obj.fn(1)
|
||||||
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
def test_cache_logcontexts(self):
|
def test_cache_logcontexts(self):
|
||||||
"""Check that logcontexts are set and restored correctly when
|
"""Check that logcontexts are set and restored correctly when
|
||||||
using the cache."""
|
using the cache."""
|
||||||
@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(LoggingContext.current_context(), c1)
|
self.assertEqual(LoggingContext.current_context(), c1)
|
||||||
|
|
||||||
|
# the cache should now be empty
|
||||||
|
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||||
|
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
|
|
||||||
# set off a deferred which will do a cache lookup
|
# set off a deferred which will do a cache lookup
|
||||||
@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(r, "chips")
|
self.assertEqual(r, "chips")
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_iterable(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached(iterable=True)
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
obj.mock.return_value = ["spam", "eggs"]
|
||||||
|
r = obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, ["spam", "eggs"])
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = ["chips"]
|
||||||
|
r = obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, ["chips"])
|
||||||
|
obj.mock.assert_called_once_with(1, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached
|
||||||
|
self.assertEqual(len(obj.fn.cache.cache), 3)
|
||||||
|
|
||||||
|
r = obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, ["spam", "eggs"])
|
||||||
|
r = obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, ["chips"])
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_cache_iterable_with_sync_exception(self):
|
||||||
|
"""If the wrapped function throws synchronously, things should continue to work
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
@descriptors.cached(iterable=True)
|
||||||
|
def fn(self, arg1):
|
||||||
|
raise SynapseError(100, "mai spoon iz too big!!1")
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
# this should fail immediately
|
||||||
|
d = obj.fn(1)
|
||||||
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
# ... leaving the cache empty
|
||||||
|
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||||
|
|
||||||
|
# and a second call should result in a second exception
|
||||||
|
d = obj.fn(1)
|
||||||
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
|
||||||
class CachedListDescriptorTestCase(unittest.TestCase):
|
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
Loading…
Reference in New Issue
Block a user