type annotations for LruCache

This commit is contained in:
Richard van der Hoff 2020-10-16 15:56:39 +01:00
parent 3ee17585cd
commit 0ec0bc3886
5 changed files with 89 additions and 31 deletions

View File

@@ -69,7 +69,9 @@ class Auth:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.token_cache = LruCache(10000, "token_cache") self.token_cache = LruCache(
10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs) self._auth_blocking = AuthBlocking(self.hs)

View File

@@ -16,7 +16,7 @@
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Pattern, Union from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import UserID from synapse.types import UserID
@@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob. # Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None) r = regex_cache.get((display_name, False, True), None)
if not r: if not r:
r = re.escape(display_name) r1 = re.escape(display_name)
r = _re_word_boundary(r) r1 = _re_word_boundary(r1)
r = re.compile(r, flags=re.IGNORECASE) r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r regex_cache[(display_name, False, True)] = r
return r.search(body) return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]: def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None) return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000, "regex_push_cache") regex_cache = LruCache(
50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool],Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r: if not r:
r = _glob_to_re(glob, word_boundary) r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r regex_cache[(glob, True, word_boundary)] = r
return r.search(value) return bool(r.search(value))
except re.error: except re.error:
logger.warning("Failed to parse glob to regex: %r", glob) logger.warning("Failed to parse glob to regex: %r", glob)
return False return False

View File

@@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
size_callback=(lambda d: len(d)) if iterable else None, size_callback=(lambda d: len(d)) if iterable else None,
metrics_collection_callback=metrics_cb, metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config, apply_cache_factor_from_config=apply_cache_factor_from_config,
) ) # type: LruCache[KT, VT]
self.thread = None # type: Optional[threading.Thread] self.thread = None # type: Optional[threading.Thread]
@@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread() self.check_thread()
if not isinstance(key, tuple): if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),)) raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the # if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None: if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict): for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate() entry.invalidate()

View File

@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import enum
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Any
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value) return len(self.value)
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DictionaryCache: class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e. """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key. fetching a subset of dictionary keys for a particular key.
""" """
def __init__(self, name, max_entries=1000): def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len) self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]
self.name = name self.name = name
self.sequence = 0 self.sequence = 0
self.thread = None self.thread = None
class Sentinel:
__slots__ = []
self.sentinel = Sentinel()
def check_thread(self): def check_thread(self):
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:
@@ -76,8 +80,8 @@ class DictionaryCache:
Returns: Returns:
DictionaryEntry DictionaryEntry
""" """
entry = self.cache.get(key, self.sentinel) entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not self.sentinel: if entry is not _Sentinel.sentinel:
if dict_keys is None: if dict_keys is None:
return DictionaryEntry( return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value) entry.full, entry.known_absent, dict(entry.value)

View File

@@ -15,12 +15,30 @@
import threading import threading
from functools import wraps from functools import wraps
from typing import Callable, Optional, Type, Union from typing import (
Any,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Literal
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
T = TypeVar("T")
FT = TypeVar("FT", bound=Callable[..., Any])
KT = TypeVar("KT")
VT = TypeVar("VT")
def enumerate_leaves(node, depth): def enumerate_leaves(node, depth):
if depth == 0: if depth == 0:
@@ -42,7 +60,7 @@ class _Node:
self.callbacks = callbacks self.callbacks = callbacks
class LruCache: class LruCache(Generic[KT, VT]):
""" """
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -128,13 +146,13 @@ class LruCache:
if metrics: if metrics:
metrics.inc_evictions(evicted_len) metrics.inc_evictions(evicted_len)
def synchronized(f): def synchronized(f: FT) -> FT:
@wraps(f) @wraps(f)
def inner(*args, **kwargs): def inner(*args, **kwargs):
with lock: with lock:
return f(*args, **kwargs) return f(*args, **kwargs)
return inner return cast(FT, inner)
cached_cache_len = [0] cached_cache_len = [0]
if size_callback is not None: if size_callback is not None:
@@ -188,8 +206,31 @@ class LruCache:
node.callbacks.clear() node.callbacks.clear()
return deleted_len return deleted_len
@overload
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@overload
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@synchronized @synchronized
def cache_get(key, default=None, callbacks=[], update_metrics=True): def cache_get(
key: KT,
default=None,
callbacks: Iterable[Callable[[], None]] = [],
update_metrics: bool = True,
):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
@@ -203,7 +244,7 @@ class LruCache:
return default return default
@synchronized @synchronized
def cache_set(key, value, callbacks=[]): def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
# We sometimes store large objects, e.g. dicts, which cause # We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +273,7 @@ class LruCache:
evict() evict()
@synchronized @synchronized
def cache_set_default(key, value): def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
return node.value return node.value
@@ -241,8 +282,16 @@ class LruCache:
evict() evict()
return value return value
@overload
def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]:
...
@overload
def cache_pop(key: KT, default: T) -> Union[T, VT]:
...
@synchronized @synchronized
def cache_pop(key, default=None): def cache_pop(key: KT, default=None):
node = cache.get(key, None) node = cache.get(key, None)
if node: if node:
delete_node(node) delete_node(node)
@@ -252,18 +301,18 @@ class LruCache:
return default return default
@synchronized @synchronized
def cache_del_multi(key): def cache_del_multi(key: KT) -> None:
""" """
This will only work if constructed with cache_type=TreeCache This will only work if constructed with cache_type=TreeCache
""" """
popped = cache.pop(key) popped = cache.pop(key)
if popped is None: if popped is None:
return return
for leaf in enumerate_leaves(popped, keylen - len(key)): for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf) delete_node(leaf)
@synchronized @synchronized
def cache_clear(): def cache_clear() -> None:
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values(): for node in cache.values():
@@ -274,7 +323,7 @@ class LruCache:
cached_cache_len[0] = 0 cached_cache_len[0] = 0
@synchronized @synchronized
def cache_contains(key): def cache_contains(key: KT) -> bool:
return key in cache return key in cache
self.sentinel = object() self.sentinel = object()