Optimise caches with single key

This commit is contained in:
Erik Johnston 2017-05-04 14:18:46 +01:00
parent 5d8290429c
commit d2d8ed4884

View File

@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache from . import register_cache
@ -163,10 +164,6 @@ class Cache(object):
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable, iterable=self.iterable,
) )
def get_cache_key(args, kwargs): def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into """Given some args/kwargs return a generator that resolves into
the cache_key. the cache_key.
@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase):
else: else:
yield self.arg_defaults[nm] yield self.arg_defaults[nm]
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if self.num_args == 1:
nm = self.arg_names[0]
def get_cache_key(args, kwargs):
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return self.arg_defaults[nm]
else:
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
cache_key = tuple(get_cache_key(args, kwargs)) cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function # Add our own `cache_context` to argument list if the wrapped function
# has asked for one # has asked for one
@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr) ret.addErrback(onErr)
# If our cache_key is a string, try to convert to ascii to save
# a bit of space in large caches
if isinstance(cache_key, basestring):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True) result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback) cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe() observer = result_d.observe()
@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase):
else: else:
return observer return observer
wrapped.invalidate = cache.invalidate if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.cache = cache wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped