Merge pull request #212 from matrix-org/erikj/cache_deferreds

Make CacheDescriptor cache deferreds rather than the deferreds' values
This commit is contained in:
Erik Johnston 2015-08-07 19:28:05 +01:00
commit 06218ab125
3 changed files with 47 additions and 19 deletions

View File

@ -15,6 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -131,6 +132,9 @@ class Cache(object):
class CacheDescriptor(object): class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache; a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value. misses use the function body to generate the value.
@ -173,13 +177,16 @@ class CacheDescriptor(object):
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try: try:
cached_result = cache.get(*keyargs) cached_result_d = cache.get(*keyargs)
observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs) actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
@ -189,17 +196,30 @@ class CacheDescriptor(object):
) )
raise ValueError("Stale cache entry") raise ValueError("Stale cache entry")
defer.returnValue(cached_result) defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = cache.sequence
ret = yield self.function_to_call(obj, *args, **kwargs) ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
cache.invalidate(*keyargs)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=False)
cache.update(sequence, *(keyargs + [ret])) cache.update(sequence, *(keyargs + [ret]))
defer.returnValue(ret) return ret.observe()
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all

View File

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View File

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached from synapse.storage._base import Cache, cached
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
d = defer.succeed(123)
class A(object): class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return d
a = A() a = A()
a.func.prefill("foo", 123) a.func.prefill("foo", ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)