mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-03 16:04:50 -04:00
Add types to synapse.util. (#10601)
This commit is contained in:
parent
ceab5a4bfa
commit
524b8ead77
41 changed files with 400 additions and 253 deletions
|
@ -15,33 +15,36 @@
|
|||
import collections
|
||||
import contextlib
|
||||
import logging
|
||||
import typing
|
||||
from typing import Any, DefaultDict, Iterator, List, Set
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.config.ratelimiting import FederationRateLimitConfig
|
||||
from synapse.logging.context import (
|
||||
PreserveLoggingContext,
|
||||
make_deferred_yieldable,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from contextlib import _GeneratorContextManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationRateLimiter:
|
||||
def __init__(self, clock, config):
|
||||
"""
|
||||
Args:
|
||||
clock (Clock)
|
||||
config (FederationRateLimitConfig)
|
||||
"""
|
||||
|
||||
def new_limiter():
|
||||
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
|
||||
def new_limiter() -> "_PerHostRatelimiter":
|
||||
return _PerHostRatelimiter(clock=clock, config=config)
|
||||
|
||||
self.ratelimiters = collections.defaultdict(new_limiter)
|
||||
self.ratelimiters: DefaultDict[
|
||||
str, "_PerHostRatelimiter"
|
||||
] = collections.defaultdict(new_limiter)
|
||||
|
||||
def ratelimit(self, host):
|
||||
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
|
||||
"""Used to ratelimit an incoming request from a given host
|
||||
|
||||
Example usage:
|
||||
|
@ -60,11 +63,11 @@ class FederationRateLimiter:
|
|||
|
||||
|
||||
class _PerHostRatelimiter:
|
||||
def __init__(self, clock, config):
|
||||
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
|
||||
"""
|
||||
Args:
|
||||
clock (Clock)
|
||||
config (FederationRateLimitConfig)
|
||||
clock
|
||||
config
|
||||
"""
|
||||
self.clock = clock
|
||||
|
||||
|
@ -75,21 +78,23 @@ class _PerHostRatelimiter:
|
|||
self.concurrent_requests = config.concurrent
|
||||
|
||||
# request_id objects for requests which have been slept
|
||||
self.sleeping_requests = set()
|
||||
self.sleeping_requests: Set[object] = set()
|
||||
|
||||
# map from request_id object to Deferred for requests which are ready
|
||||
# for processing but have been queued
|
||||
self.ready_request_queue = collections.OrderedDict()
|
||||
self.ready_request_queue: collections.OrderedDict[
|
||||
object, defer.Deferred[None]
|
||||
] = collections.OrderedDict()
|
||||
|
||||
# request id objects for requests which are in progress
|
||||
self.current_processing = set()
|
||||
self.current_processing: Set[object] = set()
|
||||
|
||||
# times at which we have recently (within the last window_size ms)
|
||||
# received requests.
|
||||
self.request_times = []
|
||||
self.request_times: List[int] = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ratelimit(self):
|
||||
def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
|
||||
# `contextlib.contextmanager` takes a generator and turns it into a
|
||||
# context manager. The generator should only yield once with a value
|
||||
# to be returned by manager.
|
||||
|
@ -102,7 +107,7 @@ class _PerHostRatelimiter:
|
|||
finally:
|
||||
self._on_exit(request_id)
|
||||
|
||||
def _on_enter(self, request_id):
|
||||
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# remove any entries from request_times which aren't within the window
|
||||
|
@ -120,9 +125,9 @@ class _PerHostRatelimiter:
|
|||
|
||||
self.request_times.append(time_now)
|
||||
|
||||
def queue_request():
|
||||
def queue_request() -> "defer.Deferred[None]":
|
||||
if len(self.current_processing) >= self.concurrent_requests:
|
||||
queue_defer = defer.Deferred()
|
||||
queue_defer: defer.Deferred[None] = defer.Deferred()
|
||||
self.ready_request_queue[request_id] = queue_defer
|
||||
logger.info(
|
||||
"Ratelimiter: queueing request (queue now %i items)",
|
||||
|
@ -145,7 +150,7 @@ class _PerHostRatelimiter:
|
|||
|
||||
self.sleeping_requests.add(request_id)
|
||||
|
||||
def on_wait_finished(_):
|
||||
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
|
||||
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
|
||||
self.sleeping_requests.discard(request_id)
|
||||
queue_defer = queue_request()
|
||||
|
@ -155,19 +160,19 @@ class _PerHostRatelimiter:
|
|||
else:
|
||||
ret_defer = queue_request()
|
||||
|
||||
def on_start(r):
|
||||
def on_start(r: object) -> object:
|
||||
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
|
||||
self.current_processing.add(request_id)
|
||||
return r
|
||||
|
||||
def on_err(r):
|
||||
def on_err(r: object) -> object:
|
||||
# XXX: why is this necessary? this is called before we start
|
||||
# processing the request so why would the request be in
|
||||
# current_processing?
|
||||
self.current_processing.discard(request_id)
|
||||
return r
|
||||
|
||||
def on_both(r):
|
||||
def on_both(r: object) -> object:
|
||||
# Ensure that we've properly cleaned up.
|
||||
self.sleeping_requests.discard(request_id)
|
||||
self.ready_request_queue.pop(request_id, None)
|
||||
|
@ -177,7 +182,7 @@ class _PerHostRatelimiter:
|
|||
ret_defer.addBoth(on_both)
|
||||
return make_deferred_yieldable(ret_defer)
|
||||
|
||||
def _on_exit(self, request_id):
|
||||
def _on_exit(self, request_id: object) -> None:
|
||||
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
|
||||
self.current_processing.discard(request_id)
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue