mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-13 16:32:17 -04:00
Support any process writing to cache invalidation stream. (#7436)
This commit is contained in:
parent
2929ce29d6
commit
d7983b63a6
26 changed files with 225 additions and 230 deletions
|
@ -16,11 +16,10 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Iterable, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
|
@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
|
|||
|
||||
|
||||
class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
async def get_all_updated_caches(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
):
|
||||
"""Fetches cache invalidation rows between the two given IDs written
|
||||
by the given instance. Returns at most `limit` rows.
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
return []
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
sql = (
|
||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit))
|
||||
sql = """
|
||||
SELECT stream_id, cache_func, keys, invalidation_ts
|
||||
FROM cache_invalidation_stream_by_instance
|
||||
WHERE stream_id > ? AND instance_name = ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, instance_name, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
return await self.db.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "caches":
|
||||
if self._cache_id_gen:
|
||||
self._cache_id_gen.advance(instance_name, token)
|
||||
|
||||
class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
||||
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
for row in rows:
|
||||
if row.cache_func == CURRENT_STATE_CACHE_NAME:
|
||||
if row.keys is None:
|
||||
raise Exception(
|
||||
"Can't send an 'invalidate all' for current state cache"
|
||||
)
|
||||
|
||||
This should only be used to invalidate caches where slaves won't
|
||||
otherwise know from other replication streams that the cache should
|
||||
be invalidated.
|
||||
"""
|
||||
cache_func = getattr(self, cache_name, None)
|
||||
if not cache_func:
|
||||
return
|
||||
room_id = row.keys[0]
|
||||
members_changed = set(row.keys[1:])
|
||||
self._invalidate_state_caches(room_id, members_changed)
|
||||
else:
|
||||
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
||||
|
||||
cache_func.invalidate(keys)
|
||||
await self.runInteraction(
|
||||
"invalidate_cache_and_stream",
|
||||
self._send_invalidation_to_replication,
|
||||
cache_func.__name__,
|
||||
keys,
|
||||
)
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
|
@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
|||
# the transaction. However, we want to only get an ID when we want
|
||||
# to use it, here, so we need to call __enter__ manually, and have
|
||||
# __exit__ called after the transaction finishes.
|
||||
ctx = self._cache_id_gen.get_next()
|
||||
stream_id = ctx.__enter__()
|
||||
txn.call_on_exception(ctx.__exit__, None, None, None)
|
||||
txn.call_after(ctx.__exit__, None, None, None)
|
||||
stream_id = self._cache_id_gen.get_next_txn(txn)
|
||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||
|
||||
if keys is not None:
|
||||
|
@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
|||
|
||||
self.db.simple_insert_txn(
|
||||
txn,
|
||||
table="cache_invalidation_stream",
|
||||
table="cache_invalidation_stream_by_instance",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"cache_func": cache_name,
|
||||
"keys": keys,
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
},
|
||||
)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
def get_cache_stream_token(self, instance_name):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
return self._cache_id_gen.get_current_token(instance_name)
|
||||
else:
|
||||
return 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue