type annotations for LruCache

Richard van der Hoff 2020-10-16 15:56:39 +01:00
parent 3ee17585cd
commit 0ec0bc3886
5 changed files with 89 additions and 31 deletions

synapse/util/caches/lrucache.py

@@ -15,12 +15,30 @@
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal

 from synapse.config import cache as cache_config
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache

+T = TypeVar("T")
+FT = TypeVar("FT", bound=Callable[..., Any])
+KT = TypeVar("KT")
+VT = TypeVar("VT")

 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -42,7 +60,7 @@ class _Node:
         self.callbacks = callbacks

-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
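With LruCache generic in KT (key type) and VT (value type), call sites can pin both and let mypy check every access. A minimal usage sketch (the room-count cache is illustrative; max_size is this class's real first constructor argument):

    from synapse.util.caches.lrucache import LruCache

    # Hypothetical cache of per-room member counts, typed str -> int.
    member_counts: LruCache[str, int] = LruCache(max_size=1000)
    member_counts.set("!room:example.org", 42)
    n = member_counts.get("!room:example.org")  # mypy infers Optional[int]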
@@ -128,13 +146,13 @@ class LruCache:
                if metrics:
                    metrics.inc_evictions(evicted_len)

-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
            @wraps(f)
            def inner(*args, **kwargs):
                with lock:
                    return f(*args, **kwargs)

-            return inner
+            return cast(FT, inner)

        cached_cache_len = [0]
        if size_callback is not None:
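Typing a wrapping decorator uses the usual TypeVar-plus-cast idiom: FT is bound to Callable[..., Any], the wrapper is built as normal, and cast(FT, inner) tells the checker the wrapper keeps f's exact signature (wraps alone only copies runtime metadata, which mypy ignores). A self-contained sketch of the same idiom outside this class (names here are illustrative, not part of the diff):

    import threading
    from functools import wraps
    from typing import Any, Callable, TypeVar, cast

    FT = TypeVar("FT", bound=Callable[..., Any])
    lock = threading.Lock()

    def synchronized(f: FT) -> FT:
        @wraps(f)
        def inner(*args: Any, **kwargs: Any) -> Any:
            with lock:
                return f(*args, **kwargs)

        # mypy sees inner as a bare Callable; the cast restores f's signature.
        return cast(FT, inner)

    @synchronized
    def add(a: int, b: int) -> int:
        return a + b

    add(1, 2)  # checked as (int, int) -> int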
@@ -188,8 +206,31 @@ class LruCache:
            node.callbacks.clear()
            return deleted_len

+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
        @synchronized
-        def cache_get(key, default=None, callbacks=[], update_metrics=True):
+        def cache_get(
+            key: KT,
+            default=None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
            node = cache.get(key, None)
            if node is not None:
                move_node_to_front(node)
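The two @overload stubs make cache_get's return type depend on the default argument: omit it (or pass None explicitly) and the result is Optional[VT]; pass a default of some other type T and the result widens to Union[T, VT]. Roughly what a checker infers at call sites, assuming a cache annotated LruCache[str, int]:

    cache: LruCache[str, int] = LruCache(max_size=10)

    a = cache.get("k")             # first overload:  Optional[int]
    b = cache.get("k", -1)         # second, T = int: Union[int, int] == int
    c = cache.get("k", "missing")  # second, T = str: Union[str, int]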
@@ -203,7 +244,7 @@ class LruCache:
                return default

        @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
            node = cache.get(key, None)
            if node is not None:
                # We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +273,7 @@ class LruCache:
            evict()

        @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
            node = cache.get(key, None)
            if node is not None:
                return node.value
@@ -241,8 +282,16 @@ class LruCache:
                evict()
            return value

+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
        @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default=None):
            node = cache.get(key, None)
            if node:
                delete_node(node)
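cache_pop gets the same treatment as cache_get: the Literal[None] annotation stops the first overload from matching once a real default is supplied. Continuing the hypothetical LruCache[str, int] example from above:

    v = cache.pop("k")      # Optional[int]
    w = cache.pop("k", -1)  # int; Union[int, int] collapses to int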
@@ -252,18 +301,18 @@ class LruCache:
                return default

        @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
            """
            This will only work if constructed with cache_type=TreeCache
            """
            popped = cache.pop(key)
            if popped is None:
                return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                delete_node(leaf)

        @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
            list_root.next_node = list_root
            list_root.prev_node = list_root
            for node in cache.values():
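cache_del_multi only works when the cache was constructed with cache_type=TreeCache, whose keys are tuples and where popping a key prefix removes a whole subtree; since KT is unconstrained, cast(tuple, key) is what lets the len() call type-check. A sketch of the intended usage (the two-part keys are illustrative):

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    # keylen=2: full keys are (room_id, user_id) pairs stored as a tree.
    state: LruCache[tuple, str] = LruCache(max_size=100, keylen=2, cache_type=TreeCache)
    state.set(("!room:hs", "@alice:hs"), "join")
    state.set(("!room:hs", "@bob:hs"), "leave")

    # Popping a length-1 prefix deletes every leaf under that room.
    state.del_multi(("!room:hs",))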
@@ -274,7 +323,7 @@ class LruCache:
                cached_cache_len[0] = 0

        @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
            return key in cache

        self.sentinel = object()