Support any process writing to cache invalidation stream. (#7436)

This commit is contained in:
Erik Johnston 2020-05-07 13:51:08 +01:00 committed by GitHub
parent 2929ce29d6
commit d7983b63a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 225 additions and 230 deletions

View file

@@ -16,11 +16,10 @@
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
from twisted.internet import defer
from typing import Any, Iterable, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit):
# Constructor: forwards its arguments to the base store initializer, then
# records the name of the running instance.
def __init__(self, database: Database, db_conn, hs):
super().__init__(database, db_conn, hs)
# Taken from `hs.get_instance_name()`; the insert further down in this file
# writes this value into the `instance_name` column.
self._instance_name = hs.get_instance_name()
# NOTE(review): this span reads like a rendered diff whose +/- markers and
# indentation were stripped, so superseded and replacement lines sit side by
# side (two early returns, two SQL statements, two `runInteraction` returns).
# Presumably the `async`/`await` lines and the query against
# `cache_invalidation_stream_by_instance` are the current ones - confirm
# against the real file before relying on this text.
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
"""
# Equal IDs mean there is nothing to fetch, so an empty result is returned
# without touching the database.
if last_id == current_id:
return defer.succeed([])
return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
# This first statement selects from `cache_invalidation_stream` and has no
# instance filter.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
# This second statement selects from `cache_invalidation_stream_by_instance`
# and additionally restricts rows to the given `instance_name`.
sql = """
SELECT stream_id, cache_func, keys, invalidation_ts
FROM cache_invalidation_stream_by_instance
WHERE stream_id > ? AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
# Rows come back as (stream_id, cache_func, keys, invalidation_ts), ordered by
# ascending stream_id.
return txn.fetchall()
return self.db.runInteraction(
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
# NOTE(review): two definitions are interleaved in this span, apparently the
# two sides of a diff whose +/- markers and indentation were lost:
#   * `process_replication_rows` - for the "caches" stream it advances
#     `self._cache_id_gen` (when set) to `token` for `instance_name`, walks
#     `rows`, and finishes by delegating to the superclass method.
#   * `CacheInvalidationStore.invalidate_cache_and_stream` - looks up the
#     attribute named `cache_name`, returns early if there is none,
#     invalidates `keys` on it and runs `_send_invalidation_to_replication`
#     through `runInteraction`.
# Nesting cannot be recovered from this text (e.g. whether the row loop sits
# inside the `stream_name == "caches"` check) - confirm against the real file.
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
cache_func = getattr(self, cache_name, None)
if not cache_func:
return
# For current-state rows the first key is the room ID and any remaining keys
# are the members that changed.
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
cache_func.invalidate(keys)
await self.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
keys,
)
super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
txn.call_on_exception(ctx.__exit__, None, None, None)
txn.call_after(ctx.__exit__, None, None, None)
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn(
txn,
table="cache_invalidation_stream",
table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
# Returns the cache ID generator's current token, or 0 when
# `self._cache_id_gen` is falsy.
# NOTE(review): both a no-argument and an `instance_name` form of the
# signature and of the `get_current_token` call appear here - presumably the
# removed and added sides of a diff; confirm which one is current.
def get_cache_stream_token(self):
def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
return self._cache_id_gen.get_current_token(instance_name)
else:
return 0