diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index a29ea8144..3331ea9eb 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -21,7 +21,7 @@ class TreeCache(object): node = self.root for k in key[:-1]: node = node.setdefault(k, {}) - node[key[-1]] = value + node[key[-1]] = _Entry(value) self.size += 1 def get(self, key, default=None): @@ -30,7 +30,7 @@ class TreeCache(object): node = node.get(k, None) if node is None: return default - return node.get(key[-1], default) + return node.get(key[-1], _Entry(default)).value def clear(self): self.size = 0 @@ -60,8 +60,33 @@ class TreeCache(object): break node_and_keys[i+1][0].pop(k) - self.size -= 1 + popped, cnt = _strip_and_count_entires(popped) + self.size -= cnt return popped def __len__(self): return self.size + + +class _Entry(object): + __slots__ = ["value"] + + def __init__(self, value): + object.__setattr__(self, "value", value) + + +def _strip_and_count_entires(d): + """Takes an _Entry or dict with leaves of _Entry's, and either returns the + value or a dictionary with _Entry's replaced by their values. + + Also returns the count of _Entry's + """ + if isinstance(d, dict): + cnt = 0 + for key, value in d.items(): + v, n = _strip_and_count_entires(value) + d[key] = v + cnt += n + return d, cnt + else: + return d.value, 1