mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Merge pull request #212 from matrix-org/erikj/cache_deferreds
Make CacheDescriptor cache deferreds rather than the deferreds' values
This commit is contained in:
commit
06218ab125
@ -15,6 +15,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||||
from synapse.util.lrucache import LruCache
|
from synapse.util.lrucache import LruCache
|
||||||
@ -131,6 +132,9 @@ class Cache(object):
|
|||||||
class CacheDescriptor(object):
|
class CacheDescriptor(object):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
fail are removed from the cache.
|
||||||
|
|
||||||
The function is presumed to take zero or more arguments, which are used in
|
The function is presumed to take zero or more arguments, which are used in
|
||||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||||
misses use the function body to generate the value.
|
misses use the function body to generate the value.
|
||||||
@ -173,13 +177,16 @@ class CacheDescriptor(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
@defer.inlineCallbacks
|
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
try:
|
try:
|
||||||
cached_result = cache.get(*keyargs)
|
cached_result_d = cache.get(*keyargs)
|
||||||
|
|
||||||
|
observer = cached_result_d.observe()
|
||||||
if DEBUG_CACHES:
|
if DEBUG_CACHES:
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_result(cached_result):
|
||||||
actual_result = yield self.function_to_call(obj, *args, **kwargs)
|
actual_result = yield self.function_to_call(obj, *args, **kwargs)
|
||||||
if actual_result != cached_result:
|
if actual_result != cached_result:
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -189,17 +196,30 @@ class CacheDescriptor(object):
|
|||||||
)
|
)
|
||||||
raise ValueError("Stale cache entry")
|
raise ValueError("Stale cache entry")
|
||||||
defer.returnValue(cached_result)
|
defer.returnValue(cached_result)
|
||||||
|
observer.addCallback(check_result)
|
||||||
|
|
||||||
|
return observer
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# Get the sequence number of the cache before reading from the
|
# Get the sequence number of the cache before reading from the
|
||||||
# database so that we can tell if the cache is invalidated
|
# database so that we can tell if the cache is invalidated
|
||||||
# while the SELECT is executing (SYN-369)
|
# while the SELECT is executing (SYN-369)
|
||||||
sequence = cache.sequence
|
sequence = cache.sequence
|
||||||
|
|
||||||
ret = yield self.function_to_call(obj, *args, **kwargs)
|
ret = defer.maybeDeferred(
|
||||||
|
self.function_to_call,
|
||||||
|
obj, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def onErr(f):
|
||||||
|
cache.invalidate(*keyargs)
|
||||||
|
return f
|
||||||
|
|
||||||
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
|
ret = ObservableDeferred(ret, consumeErrors=False)
|
||||||
cache.update(sequence, *(keyargs + [ret]))
|
cache.update(sequence, *(keyargs + [ret]))
|
||||||
|
|
||||||
defer.returnValue(ret)
|
return ret.observe()
|
||||||
|
|
||||||
wrapped.invalidate = cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
wrapped.invalidate_all = cache.invalidate_all
|
||||||
|
@ -51,7 +51,7 @@ class ObservableDeferred(object):
|
|||||||
object.__setattr__(self, "_observers", set())
|
object.__setattr__(self, "_observers", set())
|
||||||
|
|
||||||
def callback(r):
|
def callback(r):
|
||||||
self._result = (True, r)
|
object.__setattr__(self, "_result", (True, r))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
self._observers.pop().callback(r)
|
self._observers.pop().callback(r)
|
||||||
@ -60,7 +60,7 @@ class ObservableDeferred(object):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
def errback(f):
|
def errback(f):
|
||||||
self._result = (False, f)
|
object.__setattr__(self, "_result", (False, f))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
self._observers.pop().errback(f)
|
self._observers.pop().errback(f)
|
||||||
@ -97,3 +97,8 @@ class ObservableDeferred(object):
|
|||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
setattr(self._deferred, name, value)
|
setattr(self._deferred, name, value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
|
||||||
|
id(self), self._result, self._deferred,
|
||||||
|
)
|
||||||
|
@ -17,6 +17,8 @@
|
|||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from synapse.storage._base import Cache, cached
|
from synapse.storage._base import Cache, cached
|
||||||
|
|
||||||
|
|
||||||
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||||||
self.assertTrue(callcount[0] >= 14,
|
self.assertTrue(callcount[0] >= 14,
|
||||||
msg="Expected callcount >= 14, got %d" % (callcount[0]))
|
msg="Expected callcount >= 14, got %d" % (callcount[0]))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_prefill(self):
|
def test_prefill(self):
|
||||||
callcount = [0]
|
callcount = [0]
|
||||||
|
|
||||||
|
d = defer.succeed(123)
|
||||||
|
|
||||||
class A(object):
|
class A(object):
|
||||||
@cached()
|
@cached()
|
||||||
def func(self, key):
|
def func(self, key):
|
||||||
callcount[0] += 1
|
callcount[0] += 1
|
||||||
return key
|
return d
|
||||||
|
|
||||||
a = A()
|
a = A()
|
||||||
|
|
||||||
a.func.prefill("foo", 123)
|
a.func.prefill("foo", ObservableDeferred(d))
|
||||||
|
|
||||||
self.assertEquals((yield a.func("foo")), 123)
|
self.assertEquals(a.func("foo").result, d.result)
|
||||||
self.assertEquals(callcount[0], 0)
|
self.assertEquals(callcount[0], 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user