mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-11-03 20:44:02 -05:00 
			
		
		
		
	Add types to async_helpers (#8260)
This commit is contained in:
		
							parent
							
								
									1553adc831
								
							
						
					
					
						commit
						e45b834119
					
				
					 3 changed files with 88 additions and 51 deletions
				
			
		
							
								
								
									
										1
									
								
								changelog.d/8260.misc
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/8260.misc
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Add type hints to `synapse.util.async_helpers`.
 | 
			
		||||
							
								
								
									
										3
									
								
								mypy.ini
									
										
									
									
									
								
							
							
						
						
									
										3
									
								
								mypy.ini
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -34,7 +34,7 @@ files =
 | 
			
		|||
  synapse/http/federation/well_known_resolver.py,
 | 
			
		||||
  synapse/http/server.py,
 | 
			
		||||
  synapse/http/site.py,
 | 
			
		||||
  synapse/logging/,
 | 
			
		||||
  synapse/logging,
 | 
			
		||||
  synapse/metrics,
 | 
			
		||||
  synapse/module_api,
 | 
			
		||||
  synapse/notifier.py,
 | 
			
		||||
| 
						 | 
				
			
			@ -54,6 +54,7 @@ files =
 | 
			
		|||
  synapse/storage/util,
 | 
			
		||||
  synapse/streams,
 | 
			
		||||
  synapse/types.py,
 | 
			
		||||
  synapse/util/async_helpers.py,
 | 
			
		||||
  synapse/util/caches/descriptors.py,
 | 
			
		||||
  synapse/util/caches/stream_change_cache.py,
 | 
			
		||||
  synapse/util/metrics.py,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,13 +17,25 @@
 | 
			
		|||
import collections
 | 
			
		||||
import logging
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
from typing import Dict, Sequence, Set, Union
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    Dict,
 | 
			
		||||
    Hashable,
 | 
			
		||||
    Iterable,
 | 
			
		||||
    List,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Set,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
from typing_extensions import ContextManager
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
from twisted.internet.defer import CancelledError
 | 
			
		||||
from twisted.internet.interfaces import IReactorTime
 | 
			
		||||
from twisted.python import failure
 | 
			
		||||
 | 
			
		||||
from synapse.logging.context import (
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +66,7 @@ class ObservableDeferred:
 | 
			
		|||
 | 
			
		||||
    __slots__ = ["_deferred", "_observers", "_result"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, deferred, consumeErrors=False):
 | 
			
		||||
    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
 | 
			
		||||
        object.__setattr__(self, "_deferred", deferred)
 | 
			
		||||
        object.__setattr__(self, "_result", None)
 | 
			
		||||
        object.__setattr__(self, "_observers", set())
 | 
			
		||||
| 
						 | 
				
			
			@ -111,25 +123,25 @@ class ObservableDeferred:
 | 
			
		|||
            success, res = self._result
 | 
			
		||||
            return defer.succeed(res) if success else defer.fail(res)
 | 
			
		||||
 | 
			
		||||
    def observers(self):
 | 
			
		||||
    def observers(self) -> List[defer.Deferred]:
 | 
			
		||||
        return self._observers
 | 
			
		||||
 | 
			
		||||
    def has_called(self):
 | 
			
		||||
    def has_called(self) -> bool:
 | 
			
		||||
        return self._result is not None
 | 
			
		||||
 | 
			
		||||
    def has_succeeded(self):
 | 
			
		||||
    def has_succeeded(self) -> bool:
 | 
			
		||||
        return self._result is not None and self._result[0] is True
 | 
			
		||||
 | 
			
		||||
    def get_result(self):
 | 
			
		||||
    def get_result(self) -> Any:
 | 
			
		||||
        return self._result[1]
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, name):
 | 
			
		||||
    def __getattr__(self, name: str) -> Any:
 | 
			
		||||
        return getattr(self._deferred, name)
 | 
			
		||||
 | 
			
		||||
    def __setattr__(self, name, value):
 | 
			
		||||
    def __setattr__(self, name: str, value: Any) -> None:
 | 
			
		||||
        setattr(self._deferred, name, value)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
 | 
			
		||||
            id(self),
 | 
			
		||||
            self._result,
 | 
			
		||||
| 
						 | 
				
			
			@ -137,18 +149,20 @@ class ObservableDeferred:
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def concurrently_execute(func, args, limit):
 | 
			
		||||
    """Executes the function with each argument conncurrently while limiting
 | 
			
		||||
def concurrently_execute(
 | 
			
		||||
    func: Callable, args: Iterable[Any], limit: int
 | 
			
		||||
) -> defer.Deferred:
 | 
			
		||||
    """Executes the function with each argument concurrently while limiting
 | 
			
		||||
    the number of concurrent executions.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        func (func): Function to execute, should return a deferred or coroutine.
 | 
			
		||||
        args (Iterable): List of arguments to pass to func, each invocation of func
 | 
			
		||||
        func: Function to execute, should return a deferred or coroutine.
 | 
			
		||||
        args: List of arguments to pass to func, each invocation of func
 | 
			
		||||
            gets a single argument.
 | 
			
		||||
        limit (int): Maximum number of conccurent executions.
 | 
			
		||||
        limit: Maximum number of conccurent executions.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        deferred: Resolved when all function invocations have finished.
 | 
			
		||||
        Deferred[list]: Resolved when all function invocations have finished.
 | 
			
		||||
    """
 | 
			
		||||
    it = iter(args)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
 | 
			
		|||
    ).addErrback(unwrapFirstError)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def yieldable_gather_results(func, iter, *args, **kwargs):
 | 
			
		||||
def yieldable_gather_results(
 | 
			
		||||
    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
 | 
			
		||||
) -> defer.Deferred:
 | 
			
		||||
    """Executes the function with each argument concurrently.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        func (func): Function to execute that returns a Deferred
 | 
			
		||||
        iter (iter): An iterable that yields items that get passed as the first
 | 
			
		||||
        func: Function to execute that returns a Deferred
 | 
			
		||||
        iter: An iterable that yields items that get passed as the first
 | 
			
		||||
            argument to the function
 | 
			
		||||
        *args: Arguments to be passed to each call to func
 | 
			
		||||
        **kwargs: Keyword arguments to be passed to each call to func
 | 
			
		||||
 | 
			
		||||
    Returns
 | 
			
		||||
        Deferred[list]: Resolved when all functions have been invoked, or errors if
 | 
			
		||||
| 
						 | 
				
			
			@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
 | 
			
		|||
    ).addErrback(unwrapFirstError)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@attr.s(slots=True)
 | 
			
		||||
class _LinearizerEntry:
 | 
			
		||||
    # The number of things executing.
 | 
			
		||||
    count = attr.ib(type=int)
 | 
			
		||||
    # Deferreds for the things blocked from executing.
 | 
			
		||||
    deferreds = attr.ib(type=collections.OrderedDict)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Linearizer:
 | 
			
		||||
    """Limits concurrent access to resources based on a key. Useful to ensure
 | 
			
		||||
    only a few things happen at a time on a given resource.
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
 | 
			
		||||
        with (yield limiter.queue("test_key")):
 | 
			
		||||
        with await limiter.queue("test_key"):
 | 
			
		||||
            # do some work.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name=None, max_count=1, clock=None):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        name: Optional[str] = None,
 | 
			
		||||
        max_count: int = 1,
 | 
			
		||||
        clock: Optional[Clock] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            max_count(int): The maximum number of concurrent accesses
 | 
			
		||||
            max_count: The maximum number of concurrent accesses
 | 
			
		||||
        """
 | 
			
		||||
        if name is None:
 | 
			
		||||
            self.name = id(self)
 | 
			
		||||
            self.name = id(self)  # type: Union[str, int]
 | 
			
		||||
        else:
 | 
			
		||||
            self.name = name
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -216,15 +246,10 @@ class Linearizer:
 | 
			
		|||
        self._clock = clock
 | 
			
		||||
        self.max_count = max_count
 | 
			
		||||
 | 
			
		||||
        # key_to_defer is a map from the key to a 2 element list where
 | 
			
		||||
        # the first element is the number of things executing, and
 | 
			
		||||
        # the second element is an OrderedDict, where the keys are deferreds for the
 | 
			
		||||
        # things blocked from executing.
 | 
			
		||||
        self.key_to_defer = (
 | 
			
		||||
            {}
 | 
			
		||||
        )  # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
 | 
			
		||||
        # key_to_defer is a map from the key to a _LinearizerEntry.
 | 
			
		||||
        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
 | 
			
		||||
 | 
			
		||||
    def is_queued(self, key) -> bool:
 | 
			
		||||
    def is_queued(self, key: Hashable) -> bool:
 | 
			
		||||
        """Checks whether there is a process queued up waiting
 | 
			
		||||
        """
 | 
			
		||||
        entry = self.key_to_defer.get(key)
 | 
			
		||||
| 
						 | 
				
			
			@ -234,25 +259,27 @@ class Linearizer:
 | 
			
		|||
 | 
			
		||||
        # There are waiting deferreds only in the OrderedDict of deferreds is
 | 
			
		||||
        # non-empty.
 | 
			
		||||
        return bool(entry[1])
 | 
			
		||||
        return bool(entry.deferreds)
 | 
			
		||||
 | 
			
		||||
    def queue(self, key):
 | 
			
		||||
    def queue(self, key: Hashable) -> defer.Deferred:
 | 
			
		||||
        # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
 | 
			
		||||
        # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
 | 
			
		||||
        # propagated inside inlineCallbacks until Twisted 18.7)
 | 
			
		||||
        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
 | 
			
		||||
        entry = self.key_to_defer.setdefault(
 | 
			
		||||
            key, _LinearizerEntry(0, collections.OrderedDict())
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # If the number of things executing is greater than the maximum
 | 
			
		||||
        # then add a deferred to the list of blocked items
 | 
			
		||||
        # When one of the things currently executing finishes it will callback
 | 
			
		||||
        # this item so that it can continue executing.
 | 
			
		||||
        if entry[0] >= self.max_count:
 | 
			
		||||
        if entry.count >= self.max_count:
 | 
			
		||||
            res = self._await_lock(key)
 | 
			
		||||
        else:
 | 
			
		||||
            logger.debug(
 | 
			
		||||
                "Acquired uncontended linearizer lock %r for key %r", self.name, key
 | 
			
		||||
            )
 | 
			
		||||
            entry[0] += 1
 | 
			
		||||
            entry.count += 1
 | 
			
		||||
            res = defer.succeed(None)
 | 
			
		||||
 | 
			
		||||
        # once we successfully get the lock, we need to return a context manager which
 | 
			
		||||
| 
						 | 
				
			
			@ -267,15 +294,15 @@ class Linearizer:
 | 
			
		|||
 | 
			
		||||
                # We've finished executing so check if there are any things
 | 
			
		||||
                # blocked waiting to execute and start one of them
 | 
			
		||||
                entry[0] -= 1
 | 
			
		||||
                entry.count -= 1
 | 
			
		||||
 | 
			
		||||
                if entry[1]:
 | 
			
		||||
                    (next_def, _) = entry[1].popitem(last=False)
 | 
			
		||||
                if entry.deferreds:
 | 
			
		||||
                    (next_def, _) = entry.deferreds.popitem(last=False)
 | 
			
		||||
 | 
			
		||||
                    # we need to run the next thing in the sentinel context.
 | 
			
		||||
                    with PreserveLoggingContext():
 | 
			
		||||
                        next_def.callback(None)
 | 
			
		||||
                elif entry[0] == 0:
 | 
			
		||||
                elif entry.count == 0:
 | 
			
		||||
                    # We were the last thing for this key: remove it from the
 | 
			
		||||
                    # map.
 | 
			
		||||
                    del self.key_to_defer[key]
 | 
			
		||||
| 
						 | 
				
			
			@ -283,7 +310,7 @@ class Linearizer:
 | 
			
		|||
        res.addCallback(_ctx_manager)
 | 
			
		||||
        return res
 | 
			
		||||
 | 
			
		||||
    def _await_lock(self, key):
 | 
			
		||||
    def _await_lock(self, key: Hashable) -> defer.Deferred:
 | 
			
		||||
        """Helper for queue: adds a deferred to the queue
 | 
			
		||||
 | 
			
		||||
        Assumes that we've already checked that we've reached the limit of the number
 | 
			
		||||
| 
						 | 
				
			
			@ -298,11 +325,11 @@ class Linearizer:
 | 
			
		|||
        logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 | 
			
		||||
 | 
			
		||||
        new_defer = make_deferred_yieldable(defer.Deferred())
 | 
			
		||||
        entry[1][new_defer] = 1
 | 
			
		||||
        entry.deferreds[new_defer] = 1
 | 
			
		||||
 | 
			
		||||
        def cb(_r):
 | 
			
		||||
            logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
 | 
			
		||||
            entry[0] += 1
 | 
			
		||||
            entry.count += 1
 | 
			
		||||
 | 
			
		||||
            # if the code holding the lock completes synchronously, then it
 | 
			
		||||
            # will recursively run the next claimant on the list. That can
 | 
			
		||||
| 
						 | 
				
			
			@ -331,7 +358,7 @@ class Linearizer:
 | 
			
		|||
                )
 | 
			
		||||
 | 
			
		||||
            # we just have to take ourselves back out of the queue.
 | 
			
		||||
            del entry[1][new_defer]
 | 
			
		||||
            del entry.deferreds[new_defer]
 | 
			
		||||
            return e
 | 
			
		||||
 | 
			
		||||
        new_defer.addCallbacks(cb, eb)
 | 
			
		||||
| 
						 | 
				
			
			@ -419,14 +446,22 @@ class ReadWriteLock:
 | 
			
		|||
        return _ctx_manager()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _cancelled_to_timed_out_error(value, timeout):
 | 
			
		||||
R = TypeVar("R")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
 | 
			
		||||
    if isinstance(value, failure.Failure):
 | 
			
		||||
        value.trap(CancelledError)
 | 
			
		||||
        raise defer.TimeoutError(timeout, "Deferred")
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
 | 
			
		||||
def timeout_deferred(
 | 
			
		||||
    deferred: defer.Deferred,
 | 
			
		||||
    timeout: float,
 | 
			
		||||
    reactor: IReactorTime,
 | 
			
		||||
    on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
 | 
			
		||||
) -> defer.Deferred:
 | 
			
		||||
    """The in built twisted `Deferred.addTimeout` fails to time out deferreds
 | 
			
		||||
    that have a canceller that throws exceptions. This method creates a new
 | 
			
		||||
    deferred that wraps and times out the given deferred, correctly handling
 | 
			
		||||
| 
						 | 
				
			
			@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
 | 
			
		|||
    NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        deferred (Deferred)
 | 
			
		||||
        timeout (float): Timeout in seconds
 | 
			
		||||
        reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
 | 
			
		||||
        on_timeout_cancel (callable): A callable which is called immediately
 | 
			
		||||
        deferred: The Deferred to potentially timeout.
 | 
			
		||||
        timeout: Timeout in seconds
 | 
			
		||||
        reactor: The twisted reactor to use
 | 
			
		||||
        on_timeout_cancel: A callable which is called immediately
 | 
			
		||||
            after the deferred times out, and not if this deferred is
 | 
			
		||||
            otherwise cancelled before the timeout.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
 | 
			
		|||
            CancelledError Failure into a defer.TimeoutError.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Deferred
 | 
			
		||||
        A new Deferred.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    new_d = defer.Deferred()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue