diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e..7fdd84bbd 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb..1f71773aa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1, inlineCallbacks=True) def are_guests(self, user_ids): sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f940601..c5d2a3a6d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19f..758f5982b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks,