Add concept of cache contexts

This commit is contained in:
Erik Johnston 2016-08-19 11:18:26 +01:00
parent 5674ea3e6c
commit 4161ff2fc4
5 changed files with 278 additions and 20 deletions

View File

@ -55,7 +55,7 @@ class Cache(object):
)
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
if True:
cache_type = TreeCache if tree else dict
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
@ -81,8 +81,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -94,19 +94,19 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value):
def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
self.prefill(key, value, callback=callback)
def prefill(self, key, value):
def prefill(self, key, value, callback=None):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache.popitem(last=False, callback=None)
self.cache[key] = value
self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@ -151,6 +151,18 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks()
def foo(self, key, cache_context):
r1 = yield self.bar1(key, cache_context=cache_context)
r2 = yield self.bar2(key, cache_context=cache_context)
defer.returnValue(r1 + r2)
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False):
@ -168,7 +180,13 @@ class CacheDescriptor(object):
self.lru = lru
self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in self.arg_names:
self.arg_names.remove("cache_context")
self.add_cache_context = "cache_context" in all_args.args
if len(self.arg_names) < self.num_args:
raise Exception(
@ -188,10 +206,23 @@ class CacheDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
cache_context = kwargs.pop("cache_context", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
self_context = _CacheContext(cache, None)
if self.add_cache_context:
kwargs["cache_context"] = self_context
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
self_context.key = cache_key
try:
cached_result_d = cache.get(cache_key)
cached_result_d = cache.get(cache_key, callback=context_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@ -228,7 +259,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret)
cache.update(sequence, cache_key, ret, callback=context_callback)
return preserve_context_over_deferred(ret.observe())
@ -297,6 +328,12 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
cache_context = kwargs.pop("cache_context", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
res = cache.get(tuple(key))
res = cache.get(tuple(key), callback=context_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer)
cache.update(
sequence, tuple(key), observer,
callback=context_callback
)
def invalidate(f, key):
cache.invalidate(key)
@ -376,6 +416,17 @@ class CacheListDescriptor(object):
return wrapped
class _CacheContext(object):
__slots__ = ["cache", "key"]
def __init__(self, cache, key):
self.cache = cache
self.key = key
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, lru=True, tree=False):
return lambda orig: CacheDescriptor(
orig,

View File

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"]
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value):
def __init__(self, prev_node, next_node, key, value, callbacks=[]):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
self.callbacks = callbacks
class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner
def add_node(key, value):
def add_node(key, value, callbacks=[]):
prev_node = list_root
next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value)
node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks = []
@synchronized
def cache_get(key, default=None):
def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.append(callback)
return node.value
else:
return default
@synchronized
def cache_set(key, value):
def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks = []
if callback:
node.callbacks.append(callback)
move_node_to_front(node)
node.value = value
else:
add_node(key, value)
if callback:
callbacks = [callback]
else:
callbacks = []
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear():
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear()
@synchronized

View File

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self):
return self.size

View File

@ -199,3 +199,69 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached()
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, cache_context=cache_context)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached()
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, cache_context=cache_context)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)

View File

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase):
@ -79,3 +81,114 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)