Eliminate a few Anys in LruCache type hints (#11453)

Sean Quah 2021-11-30 15:39:07 +00:00 committed by GitHub
parent 432a174bc1
commit 5a0b652d36
4 changed files with 32 additions and 19 deletions

changelog.d/11453.misc (new file)

@@ -0,0 +1 @@
+Improve type hints for `LruCache`.

synapse/util/caches/deferred_cache.py

@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     MutableMapping,
     Optional,
+    Sized,
     TypeVar,
     Union,
     cast,
@@ -104,7 +105,13 @@ class DeferredCache(Generic[KT, VT]):
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d) or 1) if iterable else None,
+            size_callback=(
+                (lambda d: len(cast(Sized, d)) or 1)
+                # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
+                # We trust that `VT` is `Sized` when `iterable` is `True`
+                if iterable
+                else None
+            ),
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
             prune_unread_entries=prune_unread_entries,
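The `cast(Sized, d)` here works around mypy not knowing that `VT` supports `len()`. A minimal sketch of the same trick outside Synapse (names are illustrative, not from the codebase):

from typing import Callable, Generic, Optional, Sized, TypeVar, cast

VT = TypeVar("VT")

class SizedValueCache(Generic[VT]):
    def __init__(self, iterable: bool = False) -> None:
        # `VT` has no bound, so mypy rejects `len(value)` outright. When
        # `iterable` is True the caller promises that values are Sized, and
        # `cast` records that promise without adding any runtime check.
        self.size_callback: Optional[Callable[[VT], int]] = (
            (lambda value: len(cast(Sized, value)) or 1) if iterable else None
        )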

synapse/util/caches/lrucache.py

@@ -15,14 +15,15 @@
 import logging
 import threading
 import weakref
+from enum import Enum
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Collection,
+    Dict,
     Generic,
-    Iterable,
     List,
     Optional,
     Type,
@@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]):
         root: "ListNode[_Node]",
         key: KT,
         value: VT,
-        cache: "weakref.ReferenceType[LruCache]",
+        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
         prune_unread_entries: bool = True,
@@ -290,6 +291,12 @@ class _Node(Generic[KT, VT]):
         self._global_list_node.update_last_access(clock)


+class _Sentinel(Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
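This `_Sentinel` enum replaces the old `self.sentinel = object()` attribute (removed further down). A sketch of why the enum form types better (illustrative code, not from the diff):

from enum import Enum
from typing import Dict, Union


class _Sentinel(Enum):
    sentinel = object()


d: Dict[str, int] = {"a": 1}

# The lookup is typed Union[int, _Sentinel]; after the identity check mypy
# narrows it to int. With a plain `object()` default the result would be
# typed `object` and a cast would be needed.
result: Union[int, _Sentinel] = d.get("a", _Sentinel.sentinel)
if result is _Sentinel.sentinel:
    raise KeyError("a")
narrowed: int = result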
@@ -302,7 +309,7 @@ class LruCache(Generic[KT, VT]):
         max_size: int,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
-        size_callback: Optional[Callable] = None,
+        size_callback: Optional[Callable[[VT], int]] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
@@ -339,7 +346,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock

-        cache = cache_type()
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
         self.cache = cache  # Used for introspection.

         self.apply_cache_factor_from_config = apply_cache_factor_from_config
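Annotating the local `cache` up front is what lets the closures below type-check: a lookup such as `cache.get(key, None)` then yields `Optional[_Node[KT, VT]]` rather than `Any` (at least on the `Dict` arm of the union). A tiny illustrative sketch with hypothetical names:

from typing import Dict, Optional

cache: Dict[str, int] = dict()
node: Optional[int] = cache.get("key", None)  # precise because `cache` is annotated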
@@ -374,7 +381,7 @@ class LruCache(Generic[KT, VT]):
         # creating more each time we create a `_Node`.
         weak_ref_to_self = weakref.ref(self)

-        list_root = ListNode[_Node].create_root_node()
+        list_root = ListNode[_Node[KT, VT]].create_root_node()

         lock = threading.Lock()
@@ -422,7 +429,7 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node = _Node(
+            node: _Node[KT, VT] = _Node(
                 list_root,
                 key,
                 value,
@@ -439,10 +446,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)

-        def move_node_to_front(node: _Node) -> None:
+        def move_node_to_front(node: _Node[KT, VT]) -> None:
             node.move_to_front(real_clock, list_root)

-        def delete_node(node: _Node) -> int:
+        def delete_node(node: _Node[KT, VT]) -> int:
             node.drop_from_lists()

             deleted_len = 1
@@ -496,7 +503,7 @@ class LruCache(Generic[KT, VT]):
         @synchronized
         def cache_set(
-            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
             if node is not None:
@@ -590,8 +597,6 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache

-        self.sentinel = object()
-
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
@@ -608,18 +613,18 @@ class LruCache(Generic[KT, VT]):
         self.clear = cache_clear

     def __getitem__(self, key: KT) -> VT:
-        result = self.get(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.get(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return cast(VT, result)
+            return result

     def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)

     def __delitem__(self, key: KT, value: VT) -> None:
-        result = self.pop(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.pop(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()

     def __len__(self) -> int:
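With the enum sentinel, `self.get(key, _Sentinel.sentinel)` is typed `Union[VT, _Sentinel]`, so the `is` check narrows the result to `VT` and the old `cast(VT, result)` becomes unnecessary. A hedged usage sketch (constructor arguments beyond `max_size` omitted):

cache: LruCache[str, int] = LruCache(max_size=10)
cache["answer"] = 42
value: int = cache["answer"]  # mypy checks this end to end, no casts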

synapse/util/linked_list.py

@@ -84,7 +84,7 @@ class ListNode(Generic[P]):
         # immediately rather than at the next GC.
         self.cache_entry = None

-    def move_after(self, node: "ListNode") -> None:
+    def move_after(self, node: "ListNode[P]") -> None:
         """Move this node from its current location in the list to after the
         given node.
         """
@@ -122,7 +122,7 @@ class ListNode(Generic[P]):
         self.prev_node = None
         self.next_node = None

-    def _refs_insert_after(self, node: "ListNode") -> None:
+    def _refs_insert_after(self, node: "ListNode[P]") -> None:
         """Internal method to insert the node after the given node."""
         # This method should only be called when we're not already in the list.
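For background, a bare `ListNode` in an annotation is implicitly `ListNode[Any]`, which is what these two changes tighten. A minimal sketch of the pattern (illustrative, not the full class):

from typing import Generic, Optional, TypeVar

P = TypeVar("P")


class ListNode(Generic[P]):
    def __init__(self) -> None:
        self.prev_node: Optional["ListNode[P]"] = None
        self.next_node: Optional["ListNode[P]"] = None

    def move_after(self, node: "ListNode[P]") -> None:
        # Parametrising the annotation ties `node` to the same element
        # type as `self`, so mixing lists of different types is an error.
        ...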