mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Ensure invalidation list does not grow unboundedly
This commit is contained in:
parent
c0d7d9d642
commit
45fd2c8942
@ -25,6 +25,7 @@ from synapse.util.logcontext import (
|
|||||||
from . import DEBUG_CACHES, register_cache
|
from . import DEBUG_CACHES, register_cache
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import functools
|
import functools
|
||||||
@ -210,16 +211,17 @@ class CacheDescriptor(object):
|
|||||||
# whenever we are invalidated
|
# whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
# Add our own `cache_context` to argument list if the wrapped function
|
# Add temp cache_context so inspect.getcallargs doesn't explode
|
||||||
# has asked for one
|
|
||||||
self_context = _CacheContext(cache, None)
|
|
||||||
if self.add_cache_context:
|
if self.add_cache_context:
|
||||||
kwargs["cache_context"] = self_context
|
kwargs["cache_context"] = None
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||||
|
|
||||||
self_context.key = cache_key
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
|
# has asked for one
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = _CacheContext(cache, cache_key)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
|
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
|
||||||
@ -414,13 +416,7 @@ class CacheListDescriptor(object):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class _CacheContext(object):
|
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||||
__slots__ = ["cache", "key"]
|
|
||||||
|
|
||||||
def __init__(self, cache, key):
|
|
||||||
self.cache = cache
|
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def invalidate(self):
|
def invalidate(self):
|
||||||
self.cache.invalidate(self.key)
|
self.cache.invalidate(self.key)
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ def enumerate_leaves(node, depth):
|
|||||||
class _Node(object):
|
class _Node(object):
|
||||||
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
|
||||||
|
|
||||||
def __init__(self, prev_node, next_node, key, value, callbacks=[]):
|
def __init__(self, prev_node, next_node, key, value, callbacks=set()):
|
||||||
self.prev_node = prev_node
|
self.prev_node = prev_node
|
||||||
self.next_node = next_node
|
self.next_node = next_node
|
||||||
self.key = key
|
self.key = key
|
||||||
@ -66,7 +66,7 @@ class LruCache(object):
|
|||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
def add_node(key, value, callbacks=[]):
|
def add_node(key, value, callbacks=set()):
|
||||||
prev_node = list_root
|
prev_node = list_root
|
||||||
next_node = prev_node.next_node
|
next_node = prev_node.next_node
|
||||||
node = _Node(prev_node, next_node, key, value, callbacks)
|
node = _Node(prev_node, next_node, key, value, callbacks)
|
||||||
@ -94,7 +94,7 @@ class LruCache(object):
|
|||||||
|
|
||||||
for cb in node.callbacks:
|
for cb in node.callbacks:
|
||||||
cb()
|
cb()
|
||||||
node.callbacks = []
|
node.callbacks.clear()
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_get(key, default=None, callback=None):
|
def cache_get(key, default=None, callback=None):
|
||||||
@ -102,7 +102,7 @@ class LruCache(object):
|
|||||||
if node is not None:
|
if node is not None:
|
||||||
move_node_to_front(node)
|
move_node_to_front(node)
|
||||||
if callback:
|
if callback:
|
||||||
node.callbacks.append(callback)
|
node.callbacks.add(callback)
|
||||||
return node.value
|
return node.value
|
||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
@ -114,18 +114,18 @@ class LruCache(object):
|
|||||||
if value != node.value:
|
if value != node.value:
|
||||||
for cb in node.callbacks:
|
for cb in node.callbacks:
|
||||||
cb()
|
cb()
|
||||||
node.callbacks = []
|
node.callbacks.clear()
|
||||||
|
|
||||||
if callback:
|
if callback:
|
||||||
node.callbacks.append(callback)
|
node.callbacks.add(callback)
|
||||||
|
|
||||||
move_node_to_front(node)
|
move_node_to_front(node)
|
||||||
node.value = value
|
node.value = value
|
||||||
else:
|
else:
|
||||||
if callback:
|
if callback:
|
||||||
callbacks = [callback]
|
callbacks = set([callback])
|
||||||
else:
|
else:
|
||||||
callbacks = []
|
callbacks = set()
|
||||||
add_node(key, value, callbacks)
|
add_node(key, value, callbacks)
|
||||||
if len(cache) > max_size:
|
if len(cache) > max_size:
|
||||||
todelete = list_root.prev_node
|
todelete = list_root.prev_node
|
||||||
|
@ -17,6 +17,8 @@
|
|||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from synapse.util.caches.descriptors import Cache, cached
|
from synapse.util.caches.descriptors import Cache, cached
|
||||||
@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEquals(callcount[0], 4)
|
self.assertEquals(callcount[0], 4)
|
||||||
self.assertEquals(callcount2[0], 3)
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_double_get(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A(object):
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(cache.get("key"), 1)
|
self.assertEquals(cache.get("key"), 1)
|
||||||
self.assertEquals(cache.setdefault("key", 2), 1)
|
self.assertEquals(cache.setdefault("key", 2), 1)
|
||||||
self.assertEquals(cache.get("key"), 1)
|
self.assertEquals(cache.get("key"), 1)
|
||||||
|
cache["key"] = 2 # Make sure overriding works.
|
||||||
|
self.assertEquals(cache.get("key"), 2)
|
||||||
|
|
||||||
def test_pop(self):
|
def test_pop(self):
|
||||||
cache = LruCache(1)
|
cache = LruCache(1)
|
||||||
@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class LruCacheCallbacksTestCase(unittest.TestCase):
|
class LruCacheCallbacksTestCase(unittest.TestCase):
|
||||||
|
def test_get(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value2")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
def test_multi_get(self):
|
||||||
|
m = Mock()
|
||||||
|
cache = LruCache(1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.get("key", callback=m)
|
||||||
|
self.assertFalse(m.called)
|
||||||
|
|
||||||
|
cache.set("key", "value2")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
|
cache.set("key", "value")
|
||||||
|
self.assertEquals(m.call_count, 1)
|
||||||
|
|
||||||
def test_set(self):
|
def test_set(self):
|
||||||
m = Mock()
|
m = Mock()
|
||||||
cache = LruCache(1)
|
cache = LruCache(1)
|
||||||
|
Loading…
Reference in New Issue
Block a user