Add most missing type hints to synapse.util (#11328)

This commit is contained in:
Patrick Cloke 2021-11-16 08:47:36 -05:00 committed by GitHub
parent 3a1462f7e0
commit 7468723697
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 161 additions and 165 deletions

View file

@ -19,12 +19,15 @@ import logging
from typing import (
Any,
Callable,
Dict,
Generic,
Hashable,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
def __init__(
self,
orig: Callable[..., Any],
num_args: Optional[int],
cache_context: bool = False,
):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__(
self,
orig,
orig: Callable[..., Any],
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(
self,
orig,
max_entries=1000,
num_args=None,
tree=False,
cache_context=False,
iterable=False,
orig: Callable[..., Any],
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
def __get__(self, obj, owner):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None):
def __init__(
self,
orig: Callable[..., Any],
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
):
"""
Args:
orig (function)
cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int): number of positional arguments (excluding ``self``,
orig
cached_method_name: The name of the cached method.
list_name: Name of the argument which is the bulk lookup list
num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
def __get__(
self, obj: Optional[Any], objtype: Optional[Type] = None
) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {}
def update_results_dict(res, arg):
def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used.
if num_args == 1:
def arg_to_cache_key(arg):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
else:
keylist = list(keyargs)
def arg_to_cache_key(arg):
def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg
return tuple(keylist)
@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val)
results[e] = val
def errback(f):
def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.