Handle cancellation in EventsWorkerStore._get_events_from_cache_or_db (#12529)

Multiple calls to `EventsWorkerStore._get_events_from_cache_or_db` can
reuse the same database fetch, which is initiated by the first call.
Ensure that cancelling the first call doesn't cancel the other calls
sharing the same database fetch.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-04-25 19:39:17 +01:00 committed by GitHub
parent 813d728d09
commit 8a87b4435a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 167 additions and 34 deletions

View file

@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
"""Fetches the events in `missing_event_ids` from the database.
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
Also creates entries in `self._current_event_fetches` to allow
concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
"""
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
event_entry_map.update(missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
self._current_event_fetches[event_id] = fetching_deferred
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
return missing_events
# We must allow the database fetch to complete in the presence of
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
get_missing_events_from_db()
)
event_entry_map.update(missing_events)
if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results