Handle cancellation in EventsWorkerStore._get_events_from_cache_or_db (#12529)

Multiple calls to `EventsWorkerStore._get_events_from_cache_or_db` can
reuse the same database fetch, which is initiated by the first call.
Ensure that cancelling the first call doesn't cancel the other calls
sharing the same database fetch.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-04-25 19:39:17 +01:00 committed by GitHub
parent 813d728d09
commit 8a87b4435a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 167 additions and 34 deletions

1
changelog.d/12529.misc Normal file
View File

@ -0,0 +1 @@
Handle cancellation in `EventsWorkerStore._get_events_from_cache_or_db`.

View File

@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids.difference_update(already_fetching_ids) missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids: if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Add entries to `self._current_event_fetches` for each event we're async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
# going to pull from the DB. We use a single deferred that resolves """Fetches the events in `missing_event_ids` from the database.
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
# Note that _get_events_from_db is also responsible for turning db rows Also creates entries in `self._current_event_fetches` to allow
# into FrozenEvents (via _get_event_from_row), which involves seeing if concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
# the events have been redacted, and if so pulling the redaction event out """
# of the database to check it. log_ctx = current_context()
# log_ctx.record_event_fetch(len(missing_events_ids))
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
event_entry_map.update(missing_events) # Add entries to `self._current_event_fetches` for each event we're
except Exception as e: # going to pull from the DB. We use a single deferred that resolves
with PreserveLoggingContext(): # to all the events we pulled from the DB (this will result in this
fetching_deferred.errback(e) # function returning more events than requested, but that can happen
raise e # already due to `_get_events_from_db`).
finally: fetching_deferred: ObservableDeferred[
# Ensure that we mark these events as no longer being fetched. Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids: for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None) self._current_event_fetches[event_id] = fetching_deferred
with PreserveLoggingContext(): # Note that _get_events_from_db is also responsible for turning db rows
fetching_deferred.callback(missing_events) # into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
return missing_events
# We must allow the database fetch to complete in the presence of
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
get_missing_events_from_db()
)
event_entry_map.update(missing_events)
if already_fetching_deferreds: if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results # Wait for the other event requests to finish and add their results

View File

@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
import json import json
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator from typing import Generator, Tuple
from unittest import mock
from twisted.enterprise.adbapi import ConnectionPool from twisted.enterprise.adbapi import ConnectionPool
from twisted.internet.defer import ensureDeferred from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.api.room_versions import EventFormatVersions, RoomVersions
@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
# This next event fetch should succeed # This next event fetch should succeed
self.get_success(self.store.get_event(self.event_ids[0])) self.get_success(self.store.get_event(self.event_ids[0]))
class GetEventCancellationTestCase(unittest.HomeserverTestCase):
"""Test cancellation of `get_event` calls."""
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
self.token = self.login(self.user, "pass")
self.room = self.helper.create_room_as(self.user, tok=self.token)
res = self.helper.send(self.room, tok=self.token)
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
self.store._get_event_cache.clear()
@contextmanager
def blocking_get_event_calls(
self,
) -> Generator[
Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
]:
"""Starts two concurrent `get_event` calls for the same event.
Both `get_event` calls will use the same database fetch, which will be blocked
at the time this function returns.
Returns:
A tuple containing:
* A `Deferred` that unblocks the database fetch.
* A cancellable `Deferred` for the first `get_event` call.
* A cancellable `Deferred` for the second `get_event` call.
"""
# Patch `DatabasePool.runWithConnection` to block.
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection
async def runWithConnection(*args, **kwargs):
await unblock
return await original_runWithConnection(*args, **kwargs)
with mock.patch.object(
self.store.db_pool,
"runWithConnection",
new=runWithConnection,
):
ctx1 = LoggingContext("get_event1")
ctx2 = LoggingContext("get_event2")
async def get_event(ctx: LoggingContext) -> None:
with ctx:
await self.store.get_event(self.event_id)
get_event1 = ensureDeferred(get_event(ctx1))
get_event2 = ensureDeferred(get_event(ctx2))
# Both `get_event` calls ought to be blocked.
self.assertNoResult(get_event1)
self.assertNoResult(get_event2)
yield unblock, get_event1, get_event2
# Confirm that the two `get_event` calls shared the same database fetch.
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
def test_first_get_event_cancelled(self):
"""Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the
fetch to complete despite the cancellation. Furthermore, the first `get_event`
call must not abort before the fetch is complete, otherwise the fetch will be
using a finished logging context.
"""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the first `get_event` call.
get_event1.cancel()
# The first `get_event` call must not abort immediately, otherwise its
# logging context will be finished while it is still in use by the database
# fetch.
self.assertNoResult(get_event1)
# The second `get_event` call must not be cancelled.
self.assertNoResult(get_event2)
# Unblock the database fetch.
unblock.callback(None)
# A `CancelledError` should be raised out of the first `get_event` call.
exc = self.get_failure(get_event1, CancelledError).value
self.assertIsInstance(exc, CancelledError)
# The second `get_event` call should complete successfully.
self.get_success(get_event2)
def test_second_get_event_cancelled(self):
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.
get_event2.cancel()
# The first `get_event` call must not be cancelled.
self.assertNoResult(get_event1)
# The second `get_event` call gets cancelled immediately.
exc = self.get_failure(get_event2, CancelledError).value
self.assertIsInstance(exc, CancelledError)
# Unblock the database fetch.
unblock.callback(None)
# The first `get_event` call should complete successfully.
self.get_success(get_event1)