Add types to synapse.util. (#10601)

This commit is contained in:
reivilibre 2021-09-10 17:03:18 +01:00 committed by GitHub
parent ceab5a4bfa
commit 524b8ead77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 400 additions and 253 deletions

View file

@ -15,33 +15,36 @@
import collections
import contextlib
import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.util import Clock
if typing.TYPE_CHECKING:
from contextlib import _GeneratorContextManager
logger = logging.getLogger(__name__)
class FederationRateLimiter:
def __init__(self, clock, config):
"""
Args:
clock (Clock)
config (FederationRateLimitConfig)
"""
def new_limiter():
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(clock=clock, config=config)
self.ratelimiters = collections.defaultdict(new_limiter)
self.ratelimiters: DefaultDict[
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
def ratelimit(self, host):
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
Example usage:
@ -60,11 +63,11 @@ class FederationRateLimiter:
class _PerHostRatelimiter:
def __init__(self, clock, config):
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
"""
Args:
clock (Clock)
config (FederationRateLimitConfig)
clock
config
"""
self.clock = clock
@ -75,21 +78,23 @@ class _PerHostRatelimiter:
self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept
self.sleeping_requests = set()
self.sleeping_requests: Set[object] = set()
# map from request_id object to Deferred for requests which are ready
# for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
self.ready_request_queue: collections.OrderedDict[
object, defer.Deferred[None]
] = collections.OrderedDict()
# request id objects for requests which are in progress
self.current_processing = set()
self.current_processing: Set[object] = set()
# times at which we have recently (within the last window_size ms)
# received requests.
self.request_times = []
self.request_times: List[int] = []
@contextlib.contextmanager
def ratelimit(self):
def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
@ -102,7 +107,7 @@ class _PerHostRatelimiter:
finally:
self._on_exit(request_id)
def _on_enter(self, request_id):
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
# remove any entries from request_times which aren't within the window
@ -120,9 +125,9 @@ class _PerHostRatelimiter:
self.request_times.append(time_now)
def queue_request():
def queue_request() -> "defer.Deferred[None]":
if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred()
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
"Ratelimiter: queueing request (queue now %i items)",
@ -145,7 +150,7 @@ class _PerHostRatelimiter:
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
@ -155,19 +160,19 @@ class _PerHostRatelimiter:
else:
ret_defer = queue_request()
def on_start(r):
def on_start(r: object) -> object:
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
def on_err(r):
def on_err(r: object) -> object:
# XXX: why is this necessary? this is called before we start
# processing the request so why would the request be in
# current_processing?
self.current_processing.discard(request_id)
return r
def on_both(r):
def on_both(r: object) -> object:
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
@ -177,7 +182,7 @@ class _PerHostRatelimiter:
ret_defer.addBoth(on_both)
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
def _on_exit(self, request_id: object) -> None:
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try: