diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ca929bc23..247dd1569 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -134,7 +134,7 @@ class PushRuleStore(SQLBaseStore): return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, cache_context): # We don't use `state_group`, its there so that we can cache based diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index c38f01ead..e7a74d3da 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -146,7 +146,7 @@ class CacheDescriptor(object): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks() + @cachedInlineCallbacks(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, cache_context=cache_context) r2 = yield self.bar2(key, cache_context=cache_context) @@ -154,7 +154,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False): + inlineCallbacks=False, cache_context=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -171,15 +171,28 @@ class CacheDescriptor(object): all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] - if "cache_context" in self.arg_names: - self.arg_names.remove("cache_context") + if "cache_context" in all_args.args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + try: + self.arg_names.remove("cache_context") + except ValueError: + pass + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) - self.add_cache_context = "cache_context" in all_args.args + self.add_cache_context = cache_context if len(self.arg_names) < self.num_args: raise Exception( "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" + " (@cached cannot key off of *args or **kwargs)" % (orig.__name__,) ) @@ -193,12 +206,16 @@ class CacheDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated cache_context = kwargs.pop("cache_context", None) if cache_context: context_callback = cache_context.invalidate else: context_callback = None + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one self_context = _CacheContext(cache, None) if self.add_cache_context: kwargs["cache_context"] = self_context @@ -414,22 +431,24 @@ class _CacheContext(object): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, + cache_context=cache_context, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, inlineCallbacks=True, + cache_context=cache_context, ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ed074ce9e..eab0c8d21 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -211,7 +211,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount[0] += 1 return key - @cached() + @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 return self.func(key, cache_context=cache_context) @@ -244,7 +244,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount[0] += 1 return key - @cached() + @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 return self.func(key, cache_context=cache_context)