Mirror of https://git.anonymousland.org/anonymousland/synapse.git, synced 2025-05-08 20:25:03 -04:00
Add most missing type hints to synapse.util (#11328)
parent 3a1462f7e0
commit 7468723697
10 changed files with 161 additions and 165 deletions
@@ -19,12 +19,15 @@ import logging
 from typing import (
     Any,
     Callable,
+    Dict,
     Generic,
+    Hashable,
     Iterable,
     Mapping,
     Optional,
     Sequence,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
@@ -32,6 +35,7 @@ from typing import (
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
+from twisted.python.failure import Failure
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CacheDescriptorBase:
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        num_args: Optional[int],
+        cache_context: bool = False,
+    ):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
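The hunks above and below appear to come from synapse/util/caches/descriptors.py, whose classes implement cache decorators through Python's descriptor protocol: the decorator object defines __get__, which builds the cache and returns the wrapped callable bound to the instance it was accessed on. As a rough, self-contained illustration of that pattern only (hypothetical names, not Synapse code):

# Illustrative only: a toy descriptor-based memoizing decorator. The names
# memoized_method and Example are hypothetical and not part of Synapse.
import functools
from typing import Any, Callable, Dict, Optional, Tuple, Type


class memoized_method:
    def __init__(self, orig: Callable[..., Any]) -> None:
        self.orig = orig
        # Keyed by id(obj) for brevity; a production version would avoid
        # keeping instances alive, e.g. with weak references.
        self._caches: Dict[int, Dict[Tuple[Any, ...], Any]] = {}

    def __get__(
        self, obj: Optional[Any], owner: Optional[Type] = None
    ) -> Callable[..., Any]:
        cache = self._caches.setdefault(id(obj), {})

        @functools.wraps(self.orig)
        def _wrapped(*args: Any) -> Any:
            if args not in cache:
                cache[args] = self.orig(obj, *args)
            return cache[args]

        return _wrapped


class Example:
    @memoized_method
    def square(self, x: int) -> int:
        return x * x


e = Example()
assert e.square(3) == 9  # computed
assert e.square(3) == 9  # served from the cache on the second call

The typed __get__(self, obj, owner) signatures added in this commit annotate exactly this hook.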
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __init__(
         self,
-        orig,
+        orig: Callable[..., Any],
         max_entries: int = 1000,
         cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
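For reference, the bounded-cache behaviour that max_entries configures here is the same idea the standard library exposes for plain functions; a minimal example unrelated to Synapse:

# Stdlib analogue (illustration only) of a bounded per-function cache like
# the one max_entries sizes above.
import functools


@functools.lru_cache(maxsize=1000)
def get_thing(thing_id: str) -> str:
    return thing_id.upper()


assert get_thing("abc") == "ABC"
assert get_thing.cache_info().maxsize == 1000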
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = LruCacheDescriptor._Sentinel.sentinel
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             invalidate_callback = kwargs.pop("on_invalidate", None)
             callbacks = (invalidate_callback,) if invalidate_callback else ()
 
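_wrapped pops an optional on_invalidate keyword argument and turns it into the callbacks tuple stored alongside the cache entry. A minimal sketch of that contract, using hypothetical helpers (cache_set, cache_invalidate) rather than the real LruCache/DeferredCache API:

# Hypothetical sketch of the on_invalidate contract: the caller's callback is
# detached from the kwargs and fired when the entry is invalidated.
from typing import Any, Callable, Dict, List, Tuple


class Entry:
    def __init__(self, value: Any, callbacks: Tuple[Callable[[], None], ...]) -> None:
        self.value = value
        self.callbacks = callbacks


def cache_set(cache: Dict[Any, Entry], key: Any, value: Any, **kwargs: Any) -> None:
    invalidate_callback = kwargs.pop("on_invalidate", None)
    callbacks = (invalidate_callback,) if invalidate_callback else ()
    cache[key] = Entry(value, callbacks)


def cache_invalidate(cache: Dict[Any, Entry], key: Any) -> None:
    for callback in cache.pop(key).callbacks:
        callback()


fired: List[str] = []
store: Dict[Any, Entry] = {}
cache_set(store, "room", "state", on_invalidate=lambda: fired.append("room"))
cache_invalidate(store, "room")
assert fired == ["room"]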
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             return r1 + r2
 
     Args:
-        num_args (int): number of positional arguments (excluding ``self`` and
+        num_args: number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
 
     def __init__(
         self,
-        orig,
-        max_entries=1000,
-        num_args=None,
-        tree=False,
-        cache_context=False,
-        iterable=False,
+        orig: Callable[..., Any],
+        max_entries: int = 1000,
+        num_args: Optional[int] = None,
+        tree: bool = False,
+        cache_context: bool = False,
+        iterable: bool = False,
         prune_unread_entries: bool = True,
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
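The num_args documented above selects how many leading named arguments form the cache key. A small, self-contained sketch of that rule (build_cache_key and UserStore are hypothetical, not Synapse code):

# Illustration only: how a num_args setting maps a call to a cache key,
# mirroring the documented behaviour (use the first num_args named arguments,
# excluding self, defaulting to all of them).
import inspect
from typing import Any, Callable, Optional, Tuple


def build_cache_key(
    orig: Callable[..., Any], num_args: Optional[int], *args: Any
) -> Tuple[Any, ...]:
    arg_names = inspect.getfullargspec(orig).args[1:]  # drop self
    if num_args is None:
        num_args = len(arg_names)
    return tuple(args[:num_args])


class UserStore:
    def get_profile(self, user_id: str, field: str) -> str:
        ...


# With num_args=1 only user_id forms the key; the default uses both arguments.
assert build_cache_key(UserStore.get_profile, 1, "@alice:hs", "displayname") == ("@alice:hs",)
assert build_cache_key(UserStore.get_profile, None, "@alice:hs", "displayname") == (
    "@alice:hs",
    "displayname",
)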
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        cached_method_name: str,
+        list_name: str,
+        num_args: Optional[int] = None,
+    ):
         """
         Args:
-            orig (function)
-            cached_method_name (str): The name of the cached method.
-            list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int): number of positional arguments (excluding ``self``,
+            orig
+            cached_method_name: The name of the cached method.
+            list_name: Name of the argument which is the bulk lookup list
+            num_args: number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
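The list descriptor documented here wraps an existing @cached method for bulk fetches: entries already cached are reused, only the misses are passed to the wrapped function, and the result is a dict keyed by the entries of list_name. Ignoring the Deferred machinery, a toy synchronous sketch of that idea (bulk_lookup is hypothetical, not Synapse code):

# Hypothetical sketch of the bulk-lookup idea: reuse a single-key cache,
# compute only the misses in one call, and return a dict of results.
from typing import Any, Callable, Dict, Hashable, Iterable


def bulk_lookup(
    cache: Dict[Hashable, Any],
    fetch_many: Callable[[Iterable[Hashable]], Dict[Hashable, Any]],
    keys: Iterable[Hashable],
) -> Dict[Hashable, Any]:
    results: Dict[Hashable, Any] = {}
    misses = []
    for key in keys:
        if key in cache:
            results[key] = cache[key]
        else:
            misses.append(key)
    if misses:
        fetched = fetch_many(misses)  # one call covering all cache misses
        cache.update(fetched)         # populate the single-key cache
        results.update(fetched)
    return results


cache: Dict[Hashable, Any] = {"@alice:hs": "Alice"}
out = bulk_lookup(cache, lambda ks: {k: k.upper() for k in ks}, ["@alice:hs", "@bob:hs"])
assert out == {"@alice:hs": "Alice", "@bob:hs": "@BOB:HS"}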
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 % (self.list_name, cached_method_name)
             )
 
-    def __get__(self, obj, objtype=None):
+    def __get__(
+        self, obj: Optional[Any], objtype: Optional[Type] = None
+    ) -> Callable[..., Any]:
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
             results = {}
 
-            def update_results_dict(res, arg):
+            def update_results_dict(res: Any, arg: Hashable) -> None:
                 results[arg] = res
 
             # list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             # otherwise a tuple is used.
             if num_args == 1:
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
             else:
                 keylist = list(keyargs)
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
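arg_to_cache_key above produces one cache key per entry of the lookup list: the bare entry when the cached method takes a single key argument, otherwise the tuple of key arguments with the list-argument position substituted. A self-contained sketch of those two key shapes (make_arg_to_cache_key is hypothetical):

# Sketch only, not Synapse code: the two cache-key shapes used above.
from typing import Any, Callable, Hashable, List, Tuple


def make_arg_to_cache_key(
    keyargs: List[Any], list_pos: int, num_args: int
) -> Callable[[Hashable], Any]:
    if num_args == 1:
        return lambda arg: arg  # the entry itself is the key

    keylist = list(keyargs)

    def arg_to_cache_key(arg: Hashable) -> Tuple[Any, ...]:
        keylist[list_pos] = arg  # swap the list-argument slot per entry
        return tuple(keylist)

    return arg_to_cache_key


# e.g. a cached method keyed on (room_id, user_id), bulk-fetched over user_ids
to_key = make_arg_to_cache_key(["!room:hs", None], list_pos=1, num_args=2)
assert to_key("@alice:hs") == ("!room:hs", "@alice:hs")
assert to_key("@bob:hs") == ("!room:hs", "@bob:hs")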
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)
 
-                def complete_all(res):
+                def complete_all(res: Dict[Hashable, Any]) -> None:
                     # the wrapped function has completed. It returns a
                     # a dict. We can now resolve the observable deferreds in
                     # the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                         deferreds_map[e].callback(val)
                         results[e] = val
 
-                def errback(f):
+                def errback(f: Failure) -> Failure:
                     # the wrapped function has failed. Invalidate any cache
                     # entries we're supposed to be populating, and fail
                     # their deferreds.
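errback both invalidates the entries being populated and returns the Failure, which is the usual Twisted convention for letting an error keep propagating after cleanup. A minimal standalone example of that convention (not Synapse code):

# Minimal illustration of the Twisted errback convention used above:
# returning the Failure leaves the Deferred in the failed state, so the
# error keeps propagating after any cleanup side effects.
from twisted.internet import defer
from twisted.python.failure import Failure

invalidated = []


def errback(f: Failure) -> Failure:
    invalidated.append("cache entry")  # side effect, e.g. cache invalidation
    return f                           # keep the failure propagating


d: defer.Deferred = defer.Deferred()
d.addErrback(errback)
d.addErrback(lambda f: f.trap(RuntimeError))  # downstream errback still fires
d.errback(RuntimeError("boom"))
assert invalidated == ["cache entry"]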