Add types to async_helpers (#8260)

Patrick Cloke 2020-09-08 16:50:51 -04:00 committed by GitHub
parent 1553adc831
commit e45b834119
3 changed files with 88 additions and 51 deletions

changelog.d/8260.misc (new file, 1 line)

@@ -0,0 +1 @@
+Add type hints to `synapse.util.async_helpers`.

mypy.ini

@@ -34,7 +34,7 @@ files =
   synapse/http/federation/well_known_resolver.py,
   synapse/http/server.py,
   synapse/http/site.py,
-  synapse/logging/,
+  synapse/logging,
   synapse/metrics,
   synapse/module_api,
   synapse/notifier.py,
@@ -54,6 +54,7 @@ files =
   synapse/storage/util,
   synapse/streams,
   synapse/types.py,
+  synapse/util/async_helpers.py,
   synapse/util/caches/descriptors.py,
   synapse/util/caches/stream_change_cache.py,
   synapse/util/metrics.py,
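With `synapse/util/async_helpers.py` added to the `files =` list above, the module is included in the project's mypy run; assuming mypy is installed, running `mypy --config-file mypy.ini synapse/util/async_helpers.py` from the repository root should now type-check it.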

synapse/util/async_helpers.py

@@ -17,13 +17,25 @@
 import collections
 import logging
 from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)

 import attr
 from typing_extensions import ContextManager

 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure

 from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:

     __slots__ = ["_deferred", "_observers", "_result"]

-    def __init__(self, deferred, consumeErrors=False):
+    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -111,25 +123,25 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)

-    def observers(self):
+    def observers(self) -> List[defer.Deferred]:
         return self._observers

-    def has_called(self):
+    def has_called(self) -> bool:
         return self._result is not None

-    def has_succeeded(self):
+    def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True

-    def get_result(self):
+    def get_result(self) -> Any:
         return self._result[1]

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self._deferred, name)

-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any) -> None:
         setattr(self._deferred, name, value)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
             id(self),
             self._result,
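For reference, a minimal usage sketch of `ObservableDeferred` (not part of this diff; the string result is arbitrary):

    from twisted.internet import defer

    from synapse.util.async_helpers import ObservableDeferred

    underlying = defer.Deferred()
    observable = ObservableDeferred(underlying, consumeErrors=True)

    # Each observe() call hands back a fresh Deferred tied to the same result.
    first = observable.observe()
    second = observable.observe()

    underlying.callback("done")

    assert first.called and second.called
    assert observable.has_called() and observable.has_succeeded()
    assert observable.get_result() == "done"

Every observer fires with the shared result, which is what the `List[defer.Deferred]` and `Any` annotations above describe.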
@@ -137,18 +149,20 @@ class ObservableDeferred:
         )


-def concurrently_execute(func, args, limit):
-    """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+    func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+    """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.

     Args:
-        func (func): Function to execute, should return a deferred or coroutine.
-        args (Iterable): List of arguments to pass to func, each invocation of func
+        func: Function to execute, should return a deferred or coroutine.
+        args: List of arguments to pass to func, each invocation of func
             gets a single argument.
-        limit (int): Maximum number of conccurent executions.
+        limit: Maximum number of conccurent executions.

     Returns:
-        deferred: Resolved when all function invocations have finished.
+        Deferred[list]: Resolved when all function invocations have finished.
     """

     it = iter(args)
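A hedged usage sketch of the newly typed signature (not part of this diff; `_load_one` and `item_ids` are hypothetical, and it assumes the coroutine runs under Synapse's usual Twisted machinery):

    from typing import Iterable

    from synapse.util.async_helpers import concurrently_execute

    async def _load_one(item_id: str) -> None:
        ...  # hypothetical per-item work

    async def load_all(item_ids: Iterable[str]) -> None:
        # Runs _load_one over every id, with at most three calls in flight at once.
        await concurrently_execute(_load_one, item_ids, 3)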
@@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
     ).addErrback(unwrapFirstError)


-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
     """Executes the function with each argument concurrently.

     Args:
-        func (func): Function to execute that returns a Deferred
-        iter (iter): An iterable that yields items that get passed as the first
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
             argument to the function
         *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func

     Returns
         Deferred[list]: Resolved when all functions have been invoked, or errors if
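Similarly, a sketch of the `*args`/`**kwargs` forwarding documented above (not part of this diff; `_scale` and its arguments are hypothetical):

    from typing import Iterable, List

    from synapse.util.async_helpers import yieldable_gather_results

    async def _scale(value: int, factor: int, offset: int = 0) -> int:
        return value * factor + offset

    async def scale_all(values: Iterable[int]) -> List[int]:
        # Each value is passed as the first argument; 10 and offset=1 are
        # forwarded to every call. Resolves to the list of results.
        return await yieldable_gather_results(_scale, values, 10, offset=1)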
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
     ).addErrback(unwrapFirstError)


+@attr.s(slots=True)
+class _LinearizerEntry:
+    # The number of things executing.
+    count = attr.ib(type=int)
+    # Deferreds for the things blocked from executing.
+    deferreds = attr.ib(type=collections.OrderedDict)
+
+
 class Linearizer:
     """Limits concurrent access to resources based on a key. Useful to ensure
     only a few things happen at a time on a given resource.

     Example:

-        with (yield limiter.queue("test_key")):
+        with await limiter.queue("test_key"):
             # do some work.

     """

-    def __init__(self, name=None, max_count=1, clock=None):
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        max_count: int = 1,
+        clock: Optional[Clock] = None,
+    ):
         """
         Args:
-            max_count(int): The maximum number of concurrent accesses
+            max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)
+            self.name = id(self)  # type: Union[str, int]
         else:
             self.name = name

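The docstring example above translates into something like the following sketch (not part of this diff; the key and the work are hypothetical). Internally, each key now maps to a `_LinearizerEntry` holding the running count and the OrderedDict of blocked deferreds:

    from synapse.util.async_helpers import Linearizer

    limiter = Linearizer(name="example_per_room_limiter")

    async def update_room(room_id: str) -> None:
        # queue() resolves to a context manager; while the block runs, other
        # callers using the same room_id wait in the entry's deferreds.
        with await limiter.queue(room_id):
            ...  # at most one update per room_id executes at a time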
@@ -216,15 +246,10 @@ class Linearizer:
         self._clock = clock
         self.max_count = max_count

-        # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing, and
-        # the second element is an OrderedDict, where the keys are deferreds for the
-        # things blocked from executing.
-        self.key_to_defer = (
-            {}
-        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+        # key_to_defer is a map from the key to a _LinearizerEntry.
+        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]

-    def is_queued(self, key) -> bool:
+    def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting
         """
         entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ class Linearizer:

         # There are waiting deferreds only in the OrderedDict of deferreds is
         # non-empty.
-        return bool(entry[1])
+        return bool(entry.deferreds)

-    def queue(self, key):
+    def queue(self, key: Hashable) -> defer.Deferred:
         # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
         # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
         # propagated inside inlineCallbacks until Twisted 18.7)
-        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+        entry = self.key_to_defer.setdefault(
+            key, _LinearizerEntry(0, collections.OrderedDict())
+        )

         # If the number of things executing is greater than the maximum
         # then add a deferred to the list of blocked items
         # When one of the things currently executing finishes it will callback
         # this item so that it can continue executing.
-        if entry[0] >= self.max_count:
+        if entry.count >= self.max_count:
             res = self._await_lock(key)
         else:
             logger.debug(
                 "Acquired uncontended linearizer lock %r for key %r", self.name, key
             )
-            entry[0] += 1
+            entry.count += 1
             res = defer.succeed(None)

         # once we successfully get the lock, we need to return a context manager which
@@ -267,15 +294,15 @@ class Linearizer:

             # We've finished executing so check if there are any things
             # blocked waiting to execute and start one of them
-            entry[0] -= 1
+            entry.count -= 1

-            if entry[1]:
-                (next_def, _) = entry[1].popitem(last=False)
+            if entry.deferreds:
+                (next_def, _) = entry.deferreds.popitem(last=False)

                 # we need to run the next thing in the sentinel context.
                 with PreserveLoggingContext():
                     next_def.callback(None)
-            elif entry[0] == 0:
+            elif entry.count == 0:
                 # We were the last thing for this key: remove it from the
                 # map.
                 del self.key_to_defer[key]
@@ -283,7 +310,7 @@ class Linearizer:
         res.addCallback(_ctx_manager)
         return res

-    def _await_lock(self, key):
+    def _await_lock(self, key: Hashable) -> defer.Deferred:
         """Helper for queue: adds a deferred to the queue

         Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ class Linearizer:

         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
         new_defer = make_deferred_yieldable(defer.Deferred())
-        entry[1][new_defer] = 1
+        entry.deferreds[new_defer] = 1

         def cb(_r):
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
-            entry[0] += 1
+            entry.count += 1

             # if the code holding the lock completes synchronously, then it
             # will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ class Linearizer:
             )

             # we just have to take ourselves back out of the queue.
-            del entry[1][new_defer]
+            del entry.deferreds[new_defer]
             return e

         new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ class ReadWriteLock:
         return _ctx_manager()


-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
     if isinstance(value, failure.Failure):
         value.trap(CancelledError)
         raise defer.TimeoutError(timeout, "Deferred")
     return value


-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+    deferred: defer.Deferred,
+    timeout: float,
+    reactor: IReactorTime,
+    on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
     deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred

     Args:
-        deferred (Deferred)
-        timeout (float): Timeout in seconds
-        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
-        on_timeout_cancel (callable): A callable which is called immediately
+        deferred: The Deferred to potentially timeout.
+        timeout: Timeout in seconds
+        reactor: The twisted reactor to use
+        on_timeout_cancel: A callable which is called immediately
             after the deferred times out, and not if this deferred is
             otherwise cancelled before the timeout.
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     CancelledError Failure into a defer.TimeoutError.

     Returns:
-        Deferred
+        A new Deferred.
     """

     new_d = defer.Deferred()
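A minimal sketch of the annotated signature (not part of this diff; the 30-second timeout is arbitrary):

    from twisted.internet import defer, reactor

    from synapse.util.async_helpers import timeout_deferred

    pending = defer.Deferred()

    # A new Deferred that errbacks with defer.TimeoutError if `pending` has not
    # resolved within 30 seconds; further callbacks belong on `wrapped`, not on
    # the original deferred, which is cancelled on timeout.
    wrapped = timeout_deferred(pending, 30.0, reactor)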