Type annotations for LruCache (#8562)

* type annotations for LruCache

* changelog

* Apply suggestions from code review

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* review comments

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Richard van der Hoff 2020-10-16 17:06:50 +01:00 committed by GitHub
commit d6094176d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 31 deletions

1
changelog.d/8562.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations for `LruCache`.

View File

@ -69,7 +69,9 @@ class Auth:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.token_cache = LruCache(10000, "token_cache")
self.token_cache = LruCache(
10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs)

View File

@ -16,7 +16,7 @@
import logging
import re
from typing import Any, Dict, List, Optional, Pattern, Union
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
from synapse.types import UserID
@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
r = re.escape(display_name)
r = _re_word_boundary(r)
r = re.compile(r, flags=re.IGNORECASE)
r1 = re.escape(display_name)
r1 = _re_word_boundary(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
return r.search(body)
return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000, "regex_push_cache")
regex_cache = LruCache(
50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool], Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return r.search(value)
return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False

View File

@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
size_callback=(lambda d: len(d)) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
) # type: LruCache[KT, VT]
self.thread = None # type: Optional[threading.Thread]
@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()

View File

@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import threading
from collections import namedtuple
from typing import Any
from synapse.util.caches.lrucache import LruCache
@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value)
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
class Sentinel:
__slots__ = []
self.sentinel = Sentinel()
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@ -76,8 +80,8 @@ class DictionaryCache:
Returns:
DictionaryEntry
"""
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)

View File

@ -15,12 +15,35 @@
import threading
from functools import wraps
from typing import Callable, Optional, Type, Union
from typing import (
Any,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])
# Key and Value type for the cache
KT = TypeVar("KT")
VT = TypeVar("VT")
# a general type var, distinct from either KT or VT
T = TypeVar("T")
def enumerate_leaves(node, depth):
if depth == 0:
@ -42,7 +65,7 @@ class _Node:
self.callbacks = callbacks
class LruCache:
class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@ -128,13 +151,13 @@ class LruCache:
if metrics:
metrics.inc_evictions(evicted_len)
def synchronized(f):
def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return inner
return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@ -188,8 +211,31 @@ class LruCache:
node.callbacks.clear()
return deleted_len
@overload
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@overload
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@synchronized
def cache_get(key, default=None, callbacks=[], update_metrics=True):
def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [],
update_metrics: bool = True,
):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
@ -203,7 +249,7 @@ class LruCache:
return default
@synchronized
def cache_set(key, value, callbacks=[]):
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@ -232,7 +278,7 @@ class LruCache:
evict()
@synchronized
def cache_set_default(key, value):
def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@ -241,8 +287,16 @@ class LruCache:
evict()
return value
@overload
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
...
@overload
def cache_pop(key: KT, default: T) -> Union[T, VT]:
...
@synchronized
def cache_pop(key, default=None):
def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@ -252,18 +306,18 @@ class LruCache:
return default
@synchronized
def cache_del_multi(key):
def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
for leaf in enumerate_leaves(popped, keylen - len(key)):
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
def cache_clear():
def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@ -274,7 +328,7 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
def cache_contains(key):
def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()