diff --git a/changelog.d/13242.misc b/changelog.d/13242.misc new file mode 100644 index 000000000..7f8ec0815 --- /dev/null +++ b/changelog.d/13242.misc @@ -0,0 +1 @@ +Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e21ab0851..6a6d0dcd7 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.util.async_helpers import delay_cancellation +from synapse.util.async_helpers import delay_cancellation, maybe_awaitable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -818,12 +818,14 @@ class DatabasePool: ) for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) + await maybe_awaitable(after_callback(*after_args, **after_kwargs)) return cast(R, result) except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) + for exception_callback, after_args, after_kwargs in exception_callbacks: + await maybe_awaitable( + exception_callback(*after_args, **after_kwargs) + ) raise # To handle cancellation, we ensure that `after_callback`s and diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 1653a6a9b..2367ddeea 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -193,7 +193,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): relates_to: Optional[str], backfilled: bool, ) -> None: - self._invalidate_get_event_cache(event_id) + # This invalidates any local in-memory cached event objects, the original + # process triggering the invalidation is responsible for clearing any external + # cached objects. + self._invalidate_local_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -208,7 +211,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: - self._invalidate_get_event_cache(redacts) + self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. self.get_relations_for_event.invalidate((redacts,)) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index eb4efbb93..fa2266ba2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1669,9 +1669,9 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill() -> None: + async def prefill() -> None: for cache_entry in to_prefill: - self.store._get_event_cache.set( + await self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 621f92e23..f3935bfea 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import AsyncLruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache: AsyncLruCache[ + Tuple[str], EventCacheEntry + ] = AsyncLruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -598,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ - event_entry_map = self._get_events_from_cache( + event_entry_map = await self._get_events_from_cache( event_ids, ) @@ -710,12 +712,22 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: - self._get_event_cache.invalidate((event_id,)) + async def _invalidate_get_event_cache(self, event_id: str) -> None: + # First we invalidate the asynchronous cache instance. This may include + # out-of-process caches such as Redis/memcache. Once complete we can + # invalidate any in memory cache. The ordering is important here to + # ensure we don't pull in any remote invalid value after we invalidate + # the in-memory cache. + await self._get_event_cache.invalidate((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _get_events_from_cache( + def _invalidate_local_get_event_cache(self, event_id: str) -> None: + self._get_event_cache.invalidate_local((event_id,)) + self._event_ref.pop(event_id, None) + self._current_event_fetches.pop(event_id, None) + + async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. @@ -730,7 +742,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: # First check if it's in the event cache - ret = self._get_event_cache.get( + ret = await self._get_event_cache.get( (event_id,), None, update_metrics=update_metrics ) if ret: @@ -752,7 +764,7 @@ class EventsWorkerStore(SQLBaseStore): # We add the entry back into the cache as we want to keep # recently queried events in the cache. - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) return event_map @@ -1129,7 +1141,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry if not redacted_event: @@ -1363,7 +1375,9 @@ class EventsWorkerStore(SQLBaseStore): # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + (rid, eid) + for (rid, eid) in keys + if await self._get_event_cache.contains((eid,)) } results = dict.fromkeys(cache_results, True) remaining = [k for k in keys if k not in cache_results] diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 87b0d0903..549ce07c1 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -302,7 +302,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) - self._invalidate_get_event_cache(event_id) + txn.call_after(self._invalidate_get_event_cache, event_id) logger.info("[purge] done") diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 71a65d565..105a51867 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -843,7 +843,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) + event_map = await self._get_events_from_cache( + member_event_ids, update_metrics=False + ) missing_member_event_ids = [] for event_id in member_event_ids: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 8ed5325c5..31f41fec8 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -730,3 +730,41 @@ class LruCache(Generic[KT, VT]): # This happens e.g. in the sync code where we have an expiring cache of # lru caches. self.clear() + + +class AsyncLruCache(Generic[KT, VT]): + """ + An asynchronous wrapper around a subset of the LruCache API. + + On its own this doesn't change the behaviour but allows subclasses that + utilize external cache systems that require await behaviour to be created. + """ + + def __init__(self, *args, **kwargs): # type: ignore + self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs) + + async def get( + self, key: KT, default: Optional[T] = None, update_metrics: bool = True + ) -> Optional[VT]: + return self._lru_cache.get(key, update_metrics=update_metrics) + + async def set(self, key: KT, value: VT) -> None: + self._lru_cache.set(key, value) + + async def invalidate(self, key: KT) -> None: + # This method should invalidate any external cache and then invalidate the LruCache. + return self._lru_cache.invalidate(key) + + def invalidate_local(self, key: KT) -> None: + """Remove an entry from the local cache + + This variant of `invalidate` is useful if we know that the external + cache has already been invalidated. + """ + return self._lru_cache.invalidate(key) + + async def contains(self, key: KT) -> bool: + return self._lru_cache.contains(key) + + async def clear(self) -> None: + self._lru_cache.clear() diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ecc7cc646..e3f38fbcc 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() # The rooms should be excluded from the sync response. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 38963ce4a..46d829b06 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 8dfaa0559..9c1182ed1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)