Add cache invalidation across workers to module API (#13667)

Signed-off-by: Mathieu Velten <mathieuv@matrix.org>
This commit is contained in:
Mathieu Velten 2022-09-21 15:32:01 +02:00 committed by GitHub
parent 16e1a9d9a7
commit 6bd8763804
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 153 additions and 21 deletions

View File

@ -0,0 +1 @@
Add cache invalidation across workers to module API.

View File

@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
self, fullname: str self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]: ) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith( if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__" "synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith( ) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__" "synapse.util.caches.descriptors._LruCachedFunction.__call__"
): ):
@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct. """Fixes the `CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except: It already has *almost* the correct signature, except:

View File

@ -125,7 +125,7 @@ from synapse.types import (
) )
from synapse.util import Clock from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
if TYPE_CHECKING: if TYPE_CHECKING:
@ -836,6 +836,37 @@ class ModuleApi:
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
) )
def register_cached_function(self, cached_func: CachedFunction) -> None:
"""Register a cached function that should be invalidated across workers.
Invalidation local to a worker can be done directly using `cached_func.invalidate`,
however invalidation that needs to go to other workers needs to call `invalidate_cache`
on the module API instead.
Args:
cached_function: The cached function that will be registered to receive invalidation
locally and from other workers.
"""
self._store.register_external_cached_function(
f"{cached_func.__module__}.{cached_func.__name__}", cached_func
)
async def invalidate_cache(
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
) -> None:
"""Invalidate a cache entry of a cached function across workers. The cached function
needs to be registered on all workers first with `register_cached_function`.
Args:
cached_function: The cached function that needs an invalidation
keys: keys of the entry to invalidate, usually matching the arguments of the
cached function.
"""
cached_func.invalidate(keys)
await self._store.send_invalidation_to_replication(
f"{cached_func.__module__}.{cached_func.__name__}",
keys,
)
async def complete_sso_login_async( async def complete_sso_login_async(
self, self,
registered_user_id: str, registered_user_id: str,

View File

@ -15,12 +15,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from abc import ABCMeta from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine self.database_engine = database.engine
self.db_pool = database self.db_pool = database
self.external_cached_functions: Dict[str, CachedFunction] = {}
def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache( def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]] self, cache_name: str, key: Optional[Collection[Any]]
) -> None: ) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the """Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers, cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache. where they may not have the cache.
@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
try: try:
cache = getattr(self, cache_name) cache = getattr(self, cache_name)
except AttributeError: except AttributeError:
# We probably haven't pulled in the cache in this worker, # Check if an externally defined module cache has been registered
# which is fine. cache = self.external_cached_functions.get(cache_name)
return if not cache:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return False
if key is None: if key is None:
cache.invalidate_all() cache.invalidate_all()
@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key)) invalidate_method(tuple(key))
return True
def register_external_cached_function(
self, cache_name: str, func: CachedFunction
) -> None:
self.external_cached_functions[cache_name] = func
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
""" """

View File

@ -33,7 +33,7 @@ from synapse.storage.database import (
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return return
cache_func.invalidate(keys) cache_func.invalidate(keys)
await self.db_pool.runInteraction( await self.send_invalidation_to_replication(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__, cache_func.__name__,
keys, keys,
) )
@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream( def _invalidate_cache_and_stream(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
cache_func: _CachedFunction, cache_func: CachedFunction,
keys: Tuple[Any, ...], keys: Tuple[Any, ...],
) -> None: ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream( def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None: ) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves """Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id] txn, CURRENT_STATE_CACHE_NAME, [room_id]
) )
async def send_invalidation_to_replication(
self, cache_name: str, keys: Optional[Collection[Any]]
) -> None:
await self.db_pool.runInteraction(
"send_invalidation_to_replication",
self._send_invalidation_to_replication,
cache_name,
keys,
)
def _send_invalidation_to_replication( def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None: ) -> None:

View File

@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any]) F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Generic[F]): class CachedFunction(Generic[F]):
invalidate: Any = None invalidate: Any = None
invalidate_all: Any = None invalidate_all: Any = None
prefill: Any = None prefill: Any = None
@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
return ret2 return ret2
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache wrapped.cache = cache
obj.__dict__[self.name] = wrapped obj.__dict__[self.name] = wrapped
@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret) return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1: if self.num_args == 1:
assert not self.tree assert not self.tree
@ -572,7 +572,7 @@ def cached(
iterable: bool = False, iterable: bool = False,
prune_unread_entries: bool = True, prune_unread_entries: bool = True,
name: Optional[str] = None, name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]: ) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor( func = lambda orig: DeferredCacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -585,7 +585,7 @@ def cached(
name=name, name=name,
) )
return cast(Callable[[F], _CachedFunction[F]], func) return cast(Callable[[F], CachedFunction[F]], func)
def cachedList( def cachedList(
@ -594,7 +594,7 @@ def cachedList(
list_name: str, list_name: str,
num_args: Optional[int] = None, num_args: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]: ) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments Used to do batch lookups for an already created cache. One of the arguments
@ -631,7 +631,7 @@ def cachedList(
name=name, name=name,
) )
return cast(Callable[[F], _CachedFunction[F]], func) return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder( def _get_cache_key_builder(

View File

@ -0,0 +1,79 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import synapse
from synapse.module_api import cached
from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
FIRST_VALUE = "one"
SECOND_VALUE = "two"
KEY = "mykey"
class TestCache:
current_value = FIRST_VALUE
@cached()
async def cached_function(self, user_id: str) -> str:
return self.current_value
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
]
def test_module_cache_full_invalidation(self):
main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
worker_cache = TestCache()
worker_hs.get_module_api().register_cached_function(
worker_cache.cached_function
)
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
self.assertEqual(
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
)
main_cache.current_value = SECOND_VALUE
worker_cache.current_value = SECOND_VALUE
# No invalidation yet, should return the cached value on both the main process and the worker
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
self.assertEqual(
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
)
# Full invalidation on the main process, should be replicated on the worker that
# should returned the updated value too
self.get_success(
self.hs.get_module_api().invalidate_cache(
main_cache.cached_function, (KEY,)
)
)
self.assertEqual(
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
)
self.assertEqual(
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
)