diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc new file mode 100644 index 000000000..7f22d4229 --- /dev/null +++ b/changelog.d/9973.misc @@ -0,0 +1 @@ +Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`. diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 70207420a..26bdead56 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate_many((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) else: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ecc1f935e..f7872501a 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_relations_for_event.invalidate((relates_to,)) + self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fd87ba71a..18f07d96d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 584532211..d1237c65c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): not be deleted. """ txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fd25c8112..897fa0663 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1748,9 +1748,9 @@ class PersistEventsStore: }, ) - txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) ) if rel_type == RelationTypes.REPLACE: @@ -1903,7 +1903,7 @@ class PersistEventsStore: for user_id in user_ids: txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) @@ -1917,7 +1917,7 @@ class PersistEventsStore: def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id,), ) txn.execute( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3647276ac..edeaacd7a 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) @@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) + txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) self.db_pool.simple_delete_txn( txn, diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 371e7e4dd..104413911 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -16,16 +16,7 @@ import enum import threading -from typing import ( - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - TypeVar, - Union, - cast, -) +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union from prometheus_client import Gauge @@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache = ( cache_type() - ) # type: MutableMapping[KT, CacheEntry] + ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] def metrics_cb(): cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) @@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]): self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): + """Delete a key, or tree of entries + + If the cache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the cache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. + """ self.check_thread() - self.cache.pop(key, None) + self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, which will (a) stop it being returned @@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]): # run the invalidation callbacks now, rather than waiting for the # deferred to resolve. if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - key = cast(KT, key) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): + # _pending_deferred_cache.pop should either return a CacheEntry, or, in the + # case of a TreeCache, a dict of keys to cache entries. Either way calling + # iterate_tree_cache_entry on it will do the right thing. + for entry in iterate_tree_cache_entry(entry): entry.invalidate() def invalidate_all(self): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2ac24a2f2..d77e8edee 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any]) class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any - invalidate_many = None # type: Any prefill = None # type: Any cache = None # type: Any num_args = None # type: Any @@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): ): super().__init__(orig, num_args=num_args, cache_context=cache_context) + if tree and self.num_args < 2: + raise RuntimeError( + "tree=True is nonsensical for cached functions with a single parameter" + ) + self.max_entries = max_entries self.tree = tree self.iterable = iterable @@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) if self.num_args == 1: + assert not self.tree wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill wrapped.invalidate_all = cache.invalidate_all diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 54df407ff..d89e9d9b1 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. - Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. """ @@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_del_multi(key: KT) -> None: + """Delete an entry, or tree of entries + + If the LruCache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the LruCache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. """ - This will only work if constructed with cache_type=TreeCache - """ - popped = cache.pop(key) + popped = cache.pop(key, None) if popped is None: return # for each deleted node, we now need to remove it from the linked list @@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + self.del_multi = cache_del_multi # `invalidate` is exposed for consistency with DeferredCache, so that it can be # invalidated by the cache invalidation replication stream. - self.invalidate = cache_pop - if cache_type is TreeCache: - self.del_multi = cache_del_multi + self.invalidate = cache_del_multi self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 73502a8b0..a6df81ebf 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -89,6 +89,9 @@ class TreeCache: value. If the key is partial, the TreeCacheNode corresponding to the part of the tree that was removed. """ + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + # a list of the nodes we have touched on the way down the tree nodes = [] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bbbc27669..0277998cb 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) self.assertEquals(callcount[0], 1) self.assertEquals(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") self.assertEquals(callcount[0], 2)