Add concept of cache contexts

This commit is contained in:
Erik Johnston 2016-08-19 11:18:26 +01:00
parent 5674ea3e6c
commit 4161ff2fc4
5 changed files with 278 additions and 20 deletions

View File

@ -55,7 +55,7 @@ class Cache(object):
) )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru: if True:
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type max_size=max_entries, keylen=keylen, cache_type=cache_type
@ -81,8 +81,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, default=_CacheSentinel): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel) val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -94,19 +94,19 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value, callback=None):
    """Store ``value`` under ``key`` if the cache has not moved on.

    Args:
        sequence: the cache sequence number observed by the caller before
            it started computing ``value``.
        key: cache key to store under.
        value: value to store.
        callback: optional callable forwarded unchanged to ``prefill``.
    """
    self.check_thread()
    if not (self.sequence == sequence):
        # The sequence number no longer matches the one the cache had
        # before the SELECT was started (SYN-369), so skip the write.
        return
    self.prefill(key, value, callback=callback)
def prefill(self, key, value, callback=None):
    """Insert ``value`` under ``key`` without running the cached function.

    When ``self.max_entries`` is set, entries are popped from the backing
    cache until there is room before the new value is stored.

    Args:
        key: cache key to store under.
        value: value to store.
        callback: optional callable passed to the backing cache's ``set``.
    """
    store = self.cache
    limit = self.max_entries
    if limit is not None:
        # NOTE(review): the popped entry is removed with callback=None;
        # confirm the backing cache's popitem accepts that keyword.
        while len(store) >= limit:
            store.popitem(last=False, callback=None)
    store.set(key, value, callback=callback)
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -151,6 +151,18 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill", The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when the caches they called are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks()
def foo(self, key, cache_context):
r1 = yield self.bar1(key, cache_context=cache_context)
r2 = yield self.bar2(key, cache_context=cache_context)
defer.returnValue(r1 + r2)
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False): inlineCallbacks=False):
@ -168,7 +180,13 @@ class CacheDescriptor(object):
self.lru = lru self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in self.arg_names:
self.arg_names.remove("cache_context")
self.add_cache_context = "cache_context" in all_args.args
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
@ -188,10 +206,23 @@ class CacheDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
cache_context = kwargs.pop("cache_context", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
self_context = _CacheContext(cache, None)
if self.add_cache_context:
kwargs["cache_context"] = self_context
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
self_context.key = cache_key
try: try:
cached_result_d = cache.get(cache_key) cached_result_d = cache.get(cache_key, callback=context_callback)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -228,7 +259,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret, callback=context_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -297,6 +328,12 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
cache_context = kwargs.pop("cache_context", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = cache.get(tuple(key)) res = cache.get(tuple(key), callback=context_callback)
if not res.has_succeeded(): if not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer) cache.update(
sequence, tuple(key), observer,
callback=context_callback
)
def invalidate(f, key): def invalidate(f, key):
cache.invalidate(key) cache.invalidate(key)
@ -376,6 +416,17 @@ class CacheListDescriptor(object):
return wrapped return wrapped
class _CacheContext(object):
    """Pairs a cache with one of its keys so the entry can be invalidated.

    Attributes are plain and writable: ``cache`` is the object whose
    ``invalidate`` method is called, ``key`` is the argument passed to it.
    """

    __slots__ = ["cache", "key"]

    def __init__(self, cache, key):
        self.cache, self.key = cache, key

    def invalidate(self):
        """Invalidate ``self.key`` in ``self.cache``."""
        target = self.cache
        target.invalidate(self.key)
def cached(max_entries=1000, num_args=1, lru=True, tree=False): def cached(max_entries=1000, num_args=1, lru=True, tree=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,

View File

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
    """One entry of the linked list that backs the LRU cache.

    Args:
        prev_node: the neighbouring node on one side (or None).
        next_node: the neighbouring node on the other side (or None).
        key: the cache key this node stores.
        value: the cached value.
        callbacks: optional iterable of zero-argument callables attached
            to this entry. It is copied, so the node never shares its
            callback list with the caller or with another node.
    """

    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]

    def __init__(self, prev_node, next_node, key, value, callbacks=None):
        self.prev_node = prev_node
        self.next_node = next_node
        self.key = key
        self.value = value
        # A mutable default (``callbacks=[]``) would be one list shared by
        # every node created without callbacks, so appending a callback to
        # one node would attach it to all of them. Give each node its own.
        self.callbacks = list(callbacks) if callbacks else []
class LruCache(object): class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache. Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner return inner
def add_node(key, value): def add_node(key, value, callbacks=[]):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value) node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node prev_node.next_node = node
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks = []
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback:
node.callbacks.append(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value): def cache_set(key, value, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks = []
if callback:
node.callbacks.append(callback)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
add_node(key, value) if callback:
callbacks = [callback]
else:
callbacks = []
add_node(key, value, callbacks)
if len(cache) > max_size: if len(cache) > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node
delete_node(todelete) delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear(): def cache_clear():
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear() cache.clear()
@synchronized @synchronized

View File

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt self.size -= cnt
return popped return popped
def values(self):
    """Return a list of the ``value`` attribute of each object in ``self.root``."""
    # NOTE(review): only the direct children of ``root`` are visited;
    # presumably multi-element keys create nested sub-trees whose nodes
    # have no ``.value`` -- confirm whether a recursive walk is needed.
    collected = []
    for entry in self.root.values():
        collected.append(entry.value)
    return collected
def __len__(self): def __len__(self):
return self.size return self.size

View File

@ -199,3 +199,69 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
    """Invalidating an inner cached function also invalidates its caller.

    ``func2`` calls ``func`` and passes its ``cache_context`` along, so
    invalidating ``func``'s entry must force ``func2`` to be recomputed.
    """
    callcount = [0]
    callcount2 = [0]

    class A(object):
        @cached()
        def func(self, key):
            callcount[0] += 1
            return key

        @cached()
        def func2(self, key, cache_context):
            callcount2[0] += 1
            return self.func(key, cache_context=cache_context)

    a = A()
    yield a.func2("foo")

    # The first call computes both the outer and the inner function.
    self.assertEqual(callcount[0], 1)
    self.assertEqual(callcount2[0], 1)

    a.func.invalidate(("foo",))
    yield a.func("foo")

    # Only the inner function was called directly, so only it recomputes.
    self.assertEqual(callcount[0], 2)
    self.assertEqual(callcount2[0], 1)

    yield a.func2("foo")

    # The outer entry was invalidated too, so it recomputes; the inner
    # result is served from the cache filled by the previous call.
    self.assertEqual(callcount[0], 2)
    self.assertEqual(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
    """Evicting an inner cache entry also invalidates the dependent outer entry.

    ``func`` holds at most two entries; caching a third key pushes one
    out, and the ``func2`` result that depended on it must be recomputed.
    """
    callcount = [0]
    callcount2 = [0]

    class A(object):
        @cached(max_entries=2)
        def func(self, key):
            callcount[0] += 1
            return key

        @cached()
        def func2(self, key, cache_context):
            callcount2[0] += 1
            return self.func(key, cache_context=cache_context)

    a = A()
    yield a.func2("foo")
    yield a.func2("foo2")

    self.assertEqual(callcount[0], 2)
    self.assertEqual(callcount2[0], 2)

    # A third key in ``func``'s two-entry cache forces an eviction.
    yield a.func("foo3")

    self.assertEqual(callcount[0], 3)
    self.assertEqual(callcount2[0], 2)

    yield a.func2("foo")

    # Both counters go up: the inner entry for "foo" was evicted and the
    # outer entry that depended on it was invalidated with it.
    self.assertEqual(callcount[0], 4)
    self.assertEqual(callcount2[0], 3)

View File

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
@ -79,3 +81,114 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1 cache["key"] = 1
cache.clear() cache.clear()
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
    """Checks when callbacks registered on LruCache entries are fired."""

    def test_set(self):
        """A callback fires once, when the stored value changes."""
        m = Mock()
        cache = LruCache(1)

        cache.set("key", "value", m)
        self.assertFalse(m.called)

        # Re-setting the same value must not fire the callback.
        cache.set("key", "value")
        self.assertFalse(m.called)

        cache.set("key", "value2")
        self.assertEqual(m.call_count, 1)

        # The callback was consumed by the previous change.
        cache.set("key", "value")
        self.assertEqual(m.call_count, 1)

    def test_pop(self):
        """Popping an entry fires its callback exactly once."""
        m = Mock()
        cache = LruCache(1)

        cache.set("key", "value", m)
        self.assertFalse(m.called)

        cache.pop("key")
        self.assertEqual(m.call_count, 1)

        # The re-added entry has no callback registered.
        cache.set("key", "value")
        self.assertEqual(m.call_count, 1)

        cache.pop("key")
        self.assertEqual(m.call_count, 1)

    def test_del_multi(self):
        """del_multi fires callbacks only for entries under the deleted prefix."""
        m1 = Mock()
        m2 = Mock()
        m3 = Mock()
        m4 = Mock()
        cache = LruCache(4, 2, cache_type=TreeCache)

        cache.set(("a", "1"), "value", m1)
        cache.set(("a", "2"), "value", m2)
        cache.set(("b", "1"), "value", m3)
        cache.set(("b", "2"), "value", m4)

        self.assertEqual(m1.call_count, 0)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 0)
        self.assertEqual(m4.call_count, 0)

        cache.del_multi(("a",))

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 1)
        self.assertEqual(m3.call_count, 0)
        self.assertEqual(m4.call_count, 0)

    def test_clear(self):
        """Clearing the cache fires every registered callback."""
        m1 = Mock()
        m2 = Mock()
        cache = LruCache(5)

        cache.set("key1", "value", m1)
        cache.set("key2", "value", m2)

        self.assertEqual(m1.call_count, 0)
        self.assertEqual(m2.call_count, 0)

        cache.clear()

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 1)

    def test_eviction(self):
        """Only the evicted entry's callback fires when the cache overflows."""
        m1 = Mock(name="m1")
        m2 = Mock(name="m2")
        m3 = Mock(name="m3")
        cache = LruCache(2)

        cache.set("key1", "value", m1)
        cache.set("key2", "value", m2)

        self.assertEqual(m1.call_count, 0)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 0)

        # A third key in a two-entry cache evicts key1.
        cache.set("key3", "value", m3)

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 0)

        # Same value again: nothing is evicted and nothing fires.
        cache.set("key3", "value")

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 0)

        # Reading key2 fires nothing; the asserts after the next set show
        # that key3, not key2, is the entry evicted afterwards.
        cache.get("key2")

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 0)

        cache.set("key1", "value", m1)

        self.assertEqual(m1.call_count, 1)
        self.assertEqual(m2.call_count, 0)
        self.assertEqual(m3.call_count, 1)