Add some more type annotations to Cache

Richard van der Hoff 2020-10-14 19:40:53 +01:00
parent 629a951b49
commit 7eff59ec91
3 changed files with 62 additions and 24 deletions


@@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
-        )
+        )  # type: Cache[tuple, int]
 
     async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
         now = int(self._clock.time_msec())
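
The `# type: Cache[tuple, int]` comment is the comment form of a variable annotation: it changes nothing at runtime, but mypy reads it exactly like a modern PEP 526 annotation and then checks every later use of the variable against it. A minimal sketch with hypothetical names, not code from this commit:

from typing import Dict, Tuple

# Comment form, as used in the diff above:
last_seen = {}  # type: Dict[Tuple[str, str, str, str], int]

# Equivalent PEP 526 spelling:
last_seen_annotated: Dict[Tuple[str, str, str, str], int] = {}

last_seen[("@user:server", "token", "1.2.3.4", "agent")] = 1602700853000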


@@ -13,12 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import inspect
 import logging
 import threading
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from prometheus_client import Gauge
@@ -38,6 +49,8 @@ logger = logging.getLogger(__name__)
 CacheKey = Union[Tuple, Any]
 
 F = TypeVar("F", bound=Callable[..., Any])
+KT = TypeVar("KT")
+VT = TypeVar("VT")
 
 
 class _CachedFunction(Generic[F]):
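
The new `KT`/`VT` type variables are what allow `Cache` (below) to be parameterised per instance, so mypy can relate the key type passed to `get()`/`set()` to the value type that comes back. A self-contained sketch of the pattern using a toy class, not the Synapse one:

from typing import Dict, Generic, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class TypedCache(Generic[KT, VT]):
    """Toy cache: mypy checks keys against KT and values against VT."""

    def __init__(self) -> None:
        self._data = {}  # type: Dict[KT, VT]

    def prefill(self, key: KT, value: VT) -> None:
        self._data[key] = value

    def get(self, key: KT) -> VT:
        return self._data[key]


ips = TypedCache()  # type: TypedCache[tuple, int]
ips.prefill(("@user:server", "1.2.3.4"), 1602700853000)
# ips.prefill("not-a-tuple", "nope")  # mypy would reject both arguments
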
@@ -61,13 +74,19 @@ cache_pending_metric = Gauge(
     ["name"],
 )
 
-_CacheSentinel = object()
+
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
 
 
 class CacheEntry:
     __slots__ = ["deferred", "callbacks", "invalidated"]
 
-    def __init__(self, deferred, callbacks):
+    def __init__(
+        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+    ):
         self.deferred = deferred
         self.callbacks = set(callbacks)
         self.invalidated = False
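
The comment inside `_Sentinel` is the crux of the change: a bare `object()` sentinel has type `object`, so `dict.get(key, sentinel)` is inferred as `object` and nothing useful survives the lookup, whereas a single-member enum gives the sentinel its own type that mypy can narrow away with an `is not` check. A small standalone illustration (hypothetical function, not Synapse code):

import enum
from typing import Dict


class _Sentinel(enum.Enum):
    sentinel = object()


def lookup(d: Dict[str, int], key: str) -> int:
    val = d.get(key, _Sentinel.sentinel)   # inferred as Union[int, _Sentinel]
    if val is not _Sentinel.sentinel:
        return val                         # mypy narrows val to int here
    return -1


print(lookup({"a": 1}, "a"))  # 1
print(lookup({"a": 1}, "b"))  # -1
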
@@ -80,7 +99,13 @@ class CacheEntry:
         self.callbacks.clear()
 
 
-class Cache:
+class Cache(Generic[KT, VT]):
+    """Wraps an LruCache, adding support for Deferred results.
+
+    It expects that each entry added with set() will be a Deferred; likewise get()
+    may return an ObservableDeferred.
+    """
+
     __slots__ = (
         "cache",
         "name",
@@ -103,19 +128,23 @@ class Cache:
         Args:
             name: The name of the cache
             max_entries: Maximum amount of entries that the cache will hold
-            keylen: The length of the tuple used as the cache key
+            keylen: The length of the tuple used as the cache key. Ignored unless
+                `tree` is True.
             tree: Use a TreeCache instead of a dict as the underlying cache type
             iterable: If True, count each item in the cached object as an entry,
                 rather than each cached object
             apply_cache_factor_from_config: Whether cache factors specified in the
                 config file affect `max_entries`
-
-        Returns:
-            Cache
         """
         cache_type = TreeCache if tree else dict
-        self._pending_deferred_cache = cache_type()
+
+        # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+        self._pending_deferred_cache = (
+            cache_type()
+        )  # type: MutableMapping[KT, CacheEntry]
+
+        # cache is used for completed results and maps to the result itself, rather than
+        # a Deferred.
         self.cache = LruCache(
             max_size=max_entries,
             keylen=keylen,
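
The `MutableMapping[KT, CacheEntry]` comment is needed because the concrete backing store is picked at runtime (a plain dict or a TreeCache); annotating the attribute with the abstract mapping type keeps mypy happy with either. A tiny illustration using only standard-library types:

from collections import OrderedDict
from typing import MutableMapping

pending = {}  # type: MutableMapping[str, int]
pending["a"] = 1

# Rebinding to a different concrete mapping is still consistent with the annotation.
pending = OrderedDict()
pending["b"] = 2
print(pending)
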
@@ -155,7 +184,13 @@ class Cache:
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+    def get(
+        self,
+        key: KT,
+        default=_Sentinel.sentinel,
+        callback: Optional[Callable[[], None]] = None,
+        update_metrics: bool = True,
+    ):
         """Looks the key up in the caches.
 
         Args:
@@ -166,30 +201,32 @@
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either an ObservableDeferred or the raw result
+            Either an ObservableDeferred or the result itself
         """
         callbacks = [callback] if callback else []
-        val = self._pending_deferred_cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
+        val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+        if val is not _Sentinel.sentinel:
             val.callbacks.update(callbacks)
             if update_metrics:
                 self.metrics.inc_hits()
             return val.deferred
 
-        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
-        if val is not _CacheSentinel:
+        val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
+        if val is not _Sentinel.sentinel:
             self.metrics.inc_hits()
             return val
 
         if update_metrics:
             self.metrics.inc_misses()
 
-        if default is _CacheSentinel:
+        if default is _Sentinel.sentinel:
             raise KeyError()
         else:
             return default
 
-    def set(self, key, value, callback=None):
+    def set(
+        self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None
+    ) -> ObservableDeferred:
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
@@ -248,7 +285,7 @@ class Cache:
         observer.addCallbacks(cb, eb)
         return observable
 
-    def prefill(self, key, value, callback=None):
+    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
@@ -267,7 +304,7 @@ class Cache:
         if entry:
             entry.invalidate()
 
-    def invalidate_many(self, key):
+    def invalidate_many(self, key: KT):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
@@ -275,7 +312,7 @@ class Cache:
 
         # if we have a pending lookup for this key, remove it from the
        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
+        entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
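
`cast()` is purely a hint to the type checker; at runtime it returns its argument unchanged, so the `pop()` above behaves exactly as before. A minimal standalone example:

from typing import List, cast


def first_as_int(values: List[object]) -> int:
    # cast() performs no conversion or checking at runtime; it only tells mypy
    # to treat the expression as the stated type.
    return cast(int, values[0])


print(first_as_int([42, "ignored"]))  # 42
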
@@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )
+        )  # type: Cache[Tuple, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into


@@ -64,7 +64,8 @@ class LruCache:
         Args:
             max_size: The maximum amount of entries the cache can hold
 
-            keylen: The length of the tuple used as the cache key
+            keylen: The length of the tuple used as the cache key. Ignored unless
+                cache_type is `TreeCache`.
 
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
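
Both docstring tweaks in this commit make the same point: keylen only matters when the backing store is a TreeCache, where tuple keys of a fixed length are stored as nested dicts so that a whole key prefix can be dropped at once (which is what invalidate_many above relies on). A rough sketch of that idea, not the real TreeCache:

from typing import Any, Dict, Tuple


class TinyTreeCache:
    def __init__(self, keylen: int) -> None:
        self.keylen = keylen
        self.root = {}  # type: Dict[Any, Any]

    def set(self, key: Tuple, value: Any) -> None:
        assert len(key) == self.keylen
        node = self.root
        for part in key[:-1]:
            node = node.setdefault(part, {})
        node[key[-1]] = value

    def invalidate_prefix(self, prefix: Tuple) -> None:
        # Popping an intermediate node drops every key underneath it.
        node = self.root
        for part in prefix[:-1]:
            node = node.get(part, {})
        node.pop(prefix[-1], None)


cache = TinyTreeCache(keylen=2)
cache.set(("user1", "device1"), "token-a")
cache.set(("user1", "device2"), "token-b")
cache.invalidate_prefix(("user1",))   # removes both entries at once
print(cache.root)  # {}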