diff --git a/synapse/util/async.py b/synapse/util/async.py index 0d6f48e2d..40be7fe7e 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -102,6 +102,15 @@ class ObservableDeferred(object): def observers(self): return self._observers + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + def __getattr__(self, name): return getattr(self._deferred, name) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5d25c9e76..f31dfb22b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -33,6 +33,7 @@ import functools import inspect import threading + logger = logging.getLogger(__name__) @@ -302,16 +303,21 @@ class CacheListDescriptor(object): # cached is a dict arg -> deferred, where deferred results in a # 2-tuple (`arg`, `result`) - cached = {} + results = {} + cached_defers = {} missing = [] for arg in list_args: key = list(keyargs) key[self.list_pos] = arg try: - res = cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + res = cache.get(tuple(key)) + if not res.has_succeeded(): + res = res.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached_defers[arg] = res + else: + results[arg] = res.get_result() except KeyError: missing.append(arg) @@ -349,12 +355,21 @@ class CacheListDescriptor(object): res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + cached_defers[arg] = res - return preserve_context_over_deferred(defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) + if cached_defers: + def update_results_dict(res): + results.update(res) + return results + + return preserve_context_over_deferred(defer.gatherResults( + cached_defers.values(), + consumeErrors=True, + ).addCallback(update_results_dict).addErrback( + unwrapFirstError + )) + else: + return results obj.__dict__[self.orig.__name__] = wrapped