Fix some error cases in the caching layer. (#5749)

There was some inconsistent behaviour in the caching layer around how
exceptions were handled - particularly synchronously-thrown ones.

This seems to be most easily handled by pushing the creation of
ObservableDeferreds down from CacheDescriptor to the Cache.
This commit is contained in:
Richard van der Hoff 2019-07-25 15:59:45 +01:00 committed by GitHub
parent f16aa3a44b
commit 618bd1ee76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 35 deletions

1
changelog.d/5749.misc Normal file
View File

@ -0,0 +1 @@
Fix some error cases in the caching layer.

View File

@ -19,8 +19,7 @@ import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
import six from six import itervalues
from six import itervalues, string_types
from prometheus_client import Gauge from prometheus_client import Gauge
@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache from . import register_cache
@ -124,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics update_metrics (bool): whether to update the cache hit rate metrics
Returns: Returns:
Either a Deferred or the raw result Either an ObservableDeferred or the raw result
""" """
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel) val = self._pending_deferred_cache.get(key, _CacheSentinel)
@ -148,9 +146,14 @@ class Cache(object):
return default return default
def set(self, key, value, callback=None): def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks) observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry: if existing_entry:
@ -158,11 +161,16 @@ class Cache(object):
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def shuffle(result): def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry: if existing_entry is entry:
self.cache.set(key, result, entry.callbacks) return True
else:
# oops, the _pending_deferred_cache has been updated since # oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date. # we started our query, so we are out of date.
# #
@ -172,6 +180,12 @@ class Cache(object):
if existing_entry is not None: if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need # we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called. # to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was # That was probably done when _pending_deferred_cache was
@ -179,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may # `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now. # not have been. Either way, let's double-check now.
entry.invalidate() entry.invalidate()
return result
entry.deferred.addCallback(shuffle) def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None): def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr) ret.addErrback(onErr)
# If our cache_key is a string on py2, try to convert to ascii result_d = cache.set(cache_key, ret, callback=invalidate_callback)
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe() observer = result_d.observe()
if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer) return make_deferred_yieldable(observer)
else:
return observer
if self.num_args == 1: if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg) missing.add(arg)
if missing: if missing:
# we need an observable deferred for each entry in the list, # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the # which we put in the cache. Each deferred resolves with the
# relevant result for that key. # relevant result for that key.
deferreds_map = {} deferreds_map = {}
@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred() deferred = defer.Deferred()
deferreds_map[arg] = deferred deferreds_map[arg] = deferred
key = arg_to_cache_key(arg) key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred) cache.set(key, deferred, callback=invalidate_callback)
cache.set(key, observable, callback=invalidate_callback)
def complete_all(res): def complete_all(res):
# the wrapped function has completed. It returns a # the wrapped function has completed. It returns a

View File

@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached
from tests import unittest from tests import unittest
@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred() d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1)) cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds # lookup should return observable deferreds
self.assertIs(cache.get("key1"), d1) self.assertFalse(cache.get("key1").has_called())
self.assertIs(cache.get("key2"), d2) self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete # let one of the lookups complete
d2.callback("result2") d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2") self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation # now do the invalidation
@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips") self.assertEqual(r, "chips")
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
    """A synchronously-raised exception from the wrapped function must not
    break the cache: the failure propagates to the caller and nothing is
    left behind in the cache.
    """

    class Thrower(object):
        @cached()
        def fn(self, arg1):
            raise SynapseError(100, "mai spoon iz too big!!1")

    thrower = Thrower()

    # the first call should fail straight away ...
    self.failureResultOf(thrower.fn(1), SynapseError)

    # ... without leaving an entry in the cache ...
    self.assertEqual(len(thrower.fn.cache.cache), 0)

    # ... so a repeat call raises again instead of replaying a cached failure.
    self.failureResultOf(thrower.fn(1), SynapseError)
def test_cache_logcontexts(self): def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when """Check that logcontexts are set and restored correctly when
using the cache.""" using the cache."""
@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1) self.assertEqual(LoggingContext.current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
obj = Cls() obj = Cls()
# set off a deferred which will do a cache lookup # set off a deferred which will do a cache lookup
@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips") self.assertEqual(r, "chips")
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_iterable(self):
    """An iterable-flavoured cache should memoise list results per argument
    tuple, and serve repeat lookups without re-invoking the function."""

    class Fetcher(object):
        def __init__(self):
            self.mock = mock.Mock()

        @descriptors.cached(iterable=True)
        def fn(self, arg1, arg2):
            return self.mock(arg1, arg2)

    fetcher = Fetcher()

    fetcher.mock.return_value = ["spam", "eggs"]
    self.assertEqual(fetcher.fn(1, 2), ["spam", "eggs"])
    fetcher.mock.assert_called_once_with(1, 2)
    fetcher.mock.reset_mock()

    # a call with different params should call the mock again
    fetcher.mock.return_value = ["chips"]
    self.assertEqual(fetcher.fn(1, 3), ["chips"])
    fetcher.mock.assert_called_once_with(1, 3)
    fetcher.mock.reset_mock()

    # both earlier results should now be served from the cache, with no
    # further calls to the wrapped function
    self.assertEqual(len(fetcher.fn.cache.cache), 3)
    self.assertEqual(fetcher.fn(1, 2), ["spam", "eggs"])
    self.assertEqual(fetcher.fn(1, 3), ["chips"])
    fetcher.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
    """A synchronous exception from a function wrapped with
    ``cached(iterable=True)`` must propagate to the caller and must not
    leave a poisoned entry in the cache.
    """

    class Thrower(object):
        @descriptors.cached(iterable=True)
        def fn(self, arg1):
            raise SynapseError(100, "mai spoon iz too big!!1")

    thrower = Thrower()

    # the first call should fail straight away ...
    self.failureResultOf(thrower.fn(1), SynapseError)

    # ... without leaving an entry in the cache ...
    self.assertEqual(len(thrower.fn.cache.cache), 0)

    # ... so a repeat call raises again instead of replaying a cached failure.
    self.failureResultOf(thrower.fn(1), SynapseError)
class CachedListDescriptorTestCase(unittest.TestCase): class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks