Fix have_seen_event cache not being invalidated (#13863)

Fix https://github.com/matrix-org/synapse/issues/13856
Fix https://github.com/matrix-org/synapse/issues/13865

> Discovered while trying to make Synapse fast enough for [this MSC2716 test for importing many batches](https://github.com/matrix-org/complement/pull/214#discussion_r741678240). As an example, disabling the `have_seen_event` cache saves 10 seconds for each `/messages` request in that MSC2716 Complement test because we're not making as many federation requests for `/state` (speeding up `have_seen_event` itself is related to https://github.com/matrix-org/synapse/issues/13625) 
> 
> But this will also make `/messages` faster in general so we can include it in the [faster `/messages` milestone](https://github.com/matrix-org/synapse/milestone/11).
> 
> *-- https://github.com/matrix-org/synapse/issues/13856*


### The problem

`_invalidate_caches_for_event` doesn't run in monolith mode, which means we never even tried to clear the `have_seen_event` and other caches. And even in worker mode, it only runs on the workers, not the master (AFAICT).

Additionally, there was a bug with the key being wrong, so `_invalidate_caches_for_event` never invalidated the `have_seen_event` cache even when it did run.

Because we were using `@cachedList` wrong, it was putting items in the cache under keys like `((room_id, event_id),)` — a tuple nested inside a tuple (ex. `(('!TnCIJPKzdQdUlIyXdQ:test', '$Iu0eqEBN7qcyF1S9B3oNB3I91v2o5YOgRNPwi_78s-k'),)`) — while we were trying to invalidate with just `(room_id, event_id)`, which did nothing.
This commit is contained in:
Eric Eastwood 2022-09-27 15:55:43 -05:00 committed by GitHub
parent 35e9d6a616
commit 29269d9d3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 67 deletions

1
changelog.d/13863.bugfix Normal file
View File

@ -0,0 +1 @@
Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls.

View File

@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
# the batches as big as possible. # the batches as big as possible.
results: Set[str] = set() results: Set[str] = set()
for chunk in batch_iter(event_ids, 500): for event_ids_chunk in batch_iter(event_ids, 500):
r = await self._have_seen_events_dict( events_seen_dict = await self._have_seen_events_dict(
[(room_id, event_id) for event_id in chunk] room_id, event_ids_chunk
)
results.update(
eid for (eid, have_event) in events_seen_dict.items() if have_event
) )
results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
return results return results
@cachedList(cached_method_name="have_seen_event", list_name="keys") @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
async def _have_seen_events_dict( async def _have_seen_events_dict(
self, keys: Collection[Tuple[str, str]] self,
) -> Dict[Tuple[str, str], bool]: room_id: str,
event_ids: Collection[str],
) -> Dict[str, bool]:
"""Helper for have_seen_events """Helper for have_seen_events
Returns: Returns:
a dict {(room_id, event_id)-> bool} a dict {event_id -> bool}
""" """
# if the event cache contains the event, obviously we've seen it. # if the event cache contains the event, obviously we've seen it.
cache_results = { cache_results = {
(rid, eid) event_id
for (rid, eid) in keys for event_id in event_ids
if await self._get_event_cache.contains((eid,)) if await self._get_event_cache.contains((event_id,))
} }
results = dict.fromkeys(cache_results, True) results = dict.fromkeys(cache_results, True)
remaining = [k for k in keys if k not in cache_results] remaining = [
event_id for event_id in event_ids if event_id not in cache_results
]
if not remaining: if not remaining:
return results return results
@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE " sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] txn.database_engine, "e.event_id", remaining
) )
txn.execute(sql + clause, args) txn.execute(sql + clause, args)
found_events = {eid for eid, in txn} found_events = {eid for eid, in txn}
# ... and then we can update the results for each key # ... and then we can update the results for each key
results.update( results.update({eid: (eid in found_events) for eid in remaining})
{(rid, eid): (eid in found_events) for (rid, eid) in remaining}
)
await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results return results
@cached(max_entries=100000, tree=True) @cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool: async def have_seen_event(self, room_id: str, event_id: str) -> bool:
res = await self._have_seen_events_dict(((room_id, event_id),)) res = await self._have_seen_events_dict(room_id, [event_id])
return res[(room_id, event_id)] return res[event_id]
def _get_current_state_event_counts_txn( def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str self, txn: LoggingTransaction, room_id: str

View File

@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cache: DeferredCache[CacheKey, Any] = cached_method.cache cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args num_args = cached_method.num_args
if num_args != self.num_args:
raise Exception(
"Number of args (%s) does not match underlying cache_method_name=%s (%s)."
% (self.num_args, self.cached_method_name, num_args)
)
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
# If we're passed a cache_context then we'll want to call its # If we're passed a cache_context then we'll want to call its

View File

@ -35,66 +35,45 @@ from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest from tests import unittest
from tests.test_utils.event_injection import create_event, inject_event
class HaveSeenEventsTestCase(unittest.HomeserverTestCase): class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main self.store: EventsWorkerStore = hs.get_datastores().main
# insert some test data self.user = self.register_user("user", "pass")
for rid in ("room1", "room2"): self.token = self.login(self.user, "pass")
self.get_success( self.room_id = self.helper.create_room_as(self.user, tok=self.token)
self.store.db_pool.simple_insert(
"rooms",
{"room_id": rid, "room_version": 4},
)
)
self.event_ids: List[str] = [] self.event_ids: List[str] = []
for idx, rid in enumerate( for i in range(3):
( event = self.get_success(
"room1", inject_event(
"room1", hs,
"room1", room_version=RoomVersions.V7.identifier,
"room2", room_id=self.room_id,
sender=self.user,
type="test_event_type",
content={"body": f"foobarbaz{i}"},
)
) )
):
event_json = {"type": f"test {idx}", "room_id": rid}
event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
event_id = event.event_id
self.get_success( self.event_ids.append(event.event_id)
self.store.db_pool.simple_insert(
"events",
{
"event_id": event_id,
"room_id": rid,
"topological_ordering": idx,
"stream_ordering": idx,
"type": event.type,
"processed": True,
"outlier": False,
},
)
)
self.get_success(
self.store.db_pool.simple_insert(
"event_json",
{
"event_id": event_id,
"room_id": rid,
"json": json.dumps(event_json),
"internal_metadata": "{}",
"format_version": 3,
},
)
)
self.event_ids.append(event_id)
def test_simple(self): def test_simple(self):
with LoggingContext(name="test") as ctx: with LoggingContext(name="test") as ctx:
res = self.get_success( res = self.get_success(
self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) self.store.have_seen_events(
self.room_id, [self.event_ids[0], "eventdoesnotexist"]
)
) )
self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(res, {self.event_ids[0]})
@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# a second lookup of the same events should cause no queries # a second lookup of the same events should cause no queries
with LoggingContext(name="test") as ctx: with LoggingContext(name="test") as ctx:
res = self.get_success( res = self.get_success(
self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) self.store.have_seen_events(
self.room_id, [self.event_ids[0], "eventdoesnotexist"]
)
) )
self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
@ -116,11 +97,86 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# looking it up should now cause no db hits # looking it up should now cause no db hits
with LoggingContext(name="test") as ctx: with LoggingContext(name="test") as ctx:
res = self.get_success( res = self.get_success(
self.store.have_seen_events("room1", [self.event_ids[0]]) self.store.have_seen_events(self.room_id, [self.event_ids[0]])
) )
self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
def test_persisting_event_invalidates_cache(self):
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns
the updated value.
"""
event, event_context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
sender=self.user,
type="test_event_type",
content={"body": "garply"},
)
)
with LoggingContext(name="test") as ctx:
# First, check `have_seen_event` for an event we have not seen yet
# to prime the cache with a `false` value.
res = self.get_success(
self.store.have_seen_events(event.room_id, [event.event_id])
)
self.assertEqual(res, set())
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
# Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence
self.get_success(
persistence.persist_event(
event,
event_context,
)
)
with LoggingContext(name="test") as ctx:
# Check `have_seen_event` again and we should see the updated fact
# that we have now seen the event after persisting it.
res = self.get_success(
self.store.have_seen_events(event.room_id, [event.event_id])
)
self.assertEqual(res, {event.event_id})
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
def test_invalidate_cache_by_room_id(self):
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
"""
with LoggingContext(name="test") as ctx:
# Prime the cache with some values
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
)
self.assertEqual(res, set(self.event_ids))
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
# Clear the cache with any events associated with the `room_id`
self.store.have_seen_event.invalidate((self.room_id,))
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
)
self.assertEqual(res, set(self.event_ids))
# Since we cleared the cache, it should result in another db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
class EventCacheTestCase(unittest.HomeserverTestCase): class EventCacheTestCase(unittest.HomeserverTestCase):
"""Test that the various layers of event cache works.""" """Test that the various layers of event cache works."""

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Set from typing import Iterable, Set, Tuple
from unittest import mock from unittest import mock
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.inner_context_was_finished, "Tried to restart a finished logcontext" obj.inner_context_was_finished, "Tried to restart a finished logcontext"
) )
self.assertEqual(current_context(), SENTINEL_CONTEXT) self.assertEqual(current_context(), SENTINEL_CONTEXT)
def test_num_args_mismatch(self):
"""
Make sure someone does not accidentally use @cachedList on a method with
a mismatch in the number args to the underlying single cache method.
"""
class Cls:
@descriptors.cached(tree=True)
def fn(self, room_id, event_id):
pass
# This is wrong ❌. `@cachedList` expects to be given the same number
# of arguments as the underlying cached function, just with one of
# the arguments being an iterable
@descriptors.cachedList(cached_method_name="fn", list_name="keys")
def list_fn(self, keys: Iterable[Tuple[str, str]]):
pass
# Corrected syntax ✅
#
# @cachedList(cached_method_name="fn", list_name="event_ids")
# async def list_fn(
# self, room_id: str, event_ids: Collection[str],
# )
obj = Cls()
# Make sure this raises an error about the arg mismatch
with self.assertRaises(Exception):
obj.list_fn([("foo", "bar")])