diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1607978e2..eed60d567 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -197,6 +197,7 @@ class _CacheDescriptorBase(object): arg_spec = inspect.getargspec(orig) all_args = arg_spec.args + self.arg_spec = arg_spec if "cache_context" in all_args: if not cache_context: @@ -226,6 +227,14 @@ class _CacheDescriptorBase(object): self.num_args = num_args self.arg_names = all_args[1:num_args + 1] + if arg_spec.defaults: + self.arg_defaults = dict(zip( + all_args[-len(arg_spec.defaults):], + arg_spec.defaults + )) + else: + self.arg_defaults = {} + if "cache_context" in self.arg_names: raise Exception( "cache_context arg cannot be included among the cache keys" @@ -289,18 +298,31 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) + def get_cache_key(args, kwargs): + """Given some args/kwargs return a generator that resolves into + the cache_key. + + We loop through each arg name, looking up if its in the `kwargs`, + otherwise using the next argument in `args`. If there are no more + args then we try looking the arg name up in the defaults + """ + pos = 0 + for nm in self.arg_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield self.arg_defaults[nm] + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add temp cache_context so inspect.getcallargs doesn't explode - if self.add_cache_context: - kwargs["cache_context"] = None - - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = tuple(get_cache_key(args, kwargs)) # Add our own `cache_context` to argument list if the wrapped function # has asked for one