mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
398 lines
13 KiB
Python
398 lines
13 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2015, 2016 OpenMarket Ltd
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
|
|
import collections
|
|
import contextlib
|
|
import logging
|
|
import threading
|
|
import typing
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
ContextManager,
|
|
DefaultDict,
|
|
Dict,
|
|
Iterator,
|
|
List,
|
|
Mapping,
|
|
MutableSet,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
)
|
|
from weakref import WeakSet
|
|
|
|
from prometheus_client.core import Counter
|
|
|
|
from twisted.internet import defer
|
|
|
|
from synapse.api.errors import LimitExceededError
|
|
from synapse.config.ratelimiting import FederationRatelimitSettings
|
|
from synapse.logging.context import (
|
|
PreserveLoggingContext,
|
|
make_deferred_yieldable,
|
|
run_in_background,
|
|
)
|
|
from synapse.logging.opentracing import start_active_span
|
|
from synapse.metrics import Histogram, LaterGauge
|
|
from synapse.util import Clock
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from contextlib import _GeneratorContextManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Track how much the ratelimiter is affecting requests

# Incremented once for every request that the rate limiter delays, labelled
# with the `metrics_name` of the rate limiter that delayed it.
rate_limit_sleep_counter = Counter(
    "synapse_rate_limit_sleep",
    "Number of requests slept by the rate limiter",
    ["rate_limiter_name"],
)
# Incremented once for every request that the rate limiter rejects outright,
# labelled with the `metrics_name` of the rate limiter that rejected it.
rate_limit_reject_counter = Counter(
    "synapse_rate_limit_reject",
    "Number of requests rejected by the rate limiter",
    ["rate_limiter_name"],
)
# Times how long a request waits inside the rate limiter before it is allowed
# to proceed. Bucket boundaries are in seconds.
queue_wait_timer = Histogram(
    "synapse_rate_limit_queue_wait_time_seconds",
    "Amount of time spent waiting for the rate limiter to let our request through.",
    ["rate_limiter_name"],
    buckets=(
        0.005,
        0.01,
        0.025,
        0.05,
        0.1,
        0.25,
        0.5,
        0.75,
        1.0,
        2.5,
        5.0,
        10.0,
        20.0,
        # NOTE(review): passed as a string; presumably the metrics library
        # converts it to float infinity for the catch-all bucket -- confirm.
        "+Inf",
    ),
)
|
|
|
|
|
|
# Registry of every live `FederationRateLimiter`; each instance adds itself on
# construction and the affected-hosts gauges read it when metrics are collected.
#
# This must be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
# during trial test runs and leak a lot of memory.
_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Protects the _rate_limiter_instances set from concurrent access: hold it both
# when adding an instance and when snapshotting the set for metrics.
_rate_limiter_instances_lock = threading.Lock()
|
|
|
|
|
|
def _get_counts_from_rate_limiter_instance(
    count_func: Callable[["FederationRateLimiter"], int]
) -> Mapping[Tuple[str, ...], int]:
    """Compute a per-rate-limiter count of something (slept/rejected hosts).

    Args:
        count_func: Called once for each registered rate limiter that has a
            `metrics_name`; returns the number to report for that limiter.

    Returns:
        A mapping from `(metrics_name,)` to the value `count_func` produced.
        Rate limiters without a `metrics_name` are left out.
    """
    # Snapshot the registry under the lock so that it cannot change underneath
    # us while the Prometheus thread is collecting metrics.
    with _rate_limiter_instances_lock:
        snapshot = tuple(_rate_limiter_instances)

    # The keys are really 1-tuples, `(metrics_name,)`, but the length is left
    # unspecified in the annotation for compatibility with LaterGauge's
    # annotations. Only limiters that were given a `metrics_name` to tell them
    # apart are reported.
    per_name: Dict[Tuple[str, ...], int] = {
        (limiter.metrics_name,): count_func(limiter)
        for limiter in snapshot
        if limiter.metrics_name
    }
    return per_name
|
|
|
|
|
|
# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.


def _count_sleeping_hosts(rate_limiter_instance: "FederationRateLimiter") -> int:
    """Count the hosts whose per-host limiter currently says requests should sleep."""
    return sum(
        per_host.should_sleep()
        for per_host in rate_limiter_instance.ratelimiters.values()
    )


def _count_rejecting_hosts(rate_limiter_instance: "FederationRateLimiter") -> int:
    """Count the hosts whose per-host limiter currently says requests should be rejected."""
    return sum(
        per_host.should_reject()
        for per_host in rate_limiter_instance.ratelimiters.values()
    )


LaterGauge(
    "synapse_rate_limit_sleep_affected_hosts",
    "Number of hosts that had requests put to sleep",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(_count_sleeping_hosts),
)
LaterGauge(
    "synapse_rate_limit_reject_affected_hosts",
    "Number of hosts that had requests rejected",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(_count_rejecting_hosts),
)
|
|
|
|
|
|
class FederationRateLimiter:
    """Used to rate limit request per-host."""

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock: Passed through to every per-host rate limiter.
            config: Rate limit settings shared by every per-host rate limiter.
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.metrics_name = metrics_name

        # One `_PerHostRatelimiter` per host, created lazily the first time a
        # host is looked up; they all share the same clock, config and name.
        self.ratelimiters: DefaultDict[
            str, "_PerHostRatelimiter"
        ] = collections.defaultdict(
            lambda: _PerHostRatelimiter(
                clock=clock, config=config, metrics_name=metrics_name
            )
        )

        # Register ourselves so that the affected-hosts gauges can find us.
        with _rate_limiter_instances_lock:
            _rate_limiter_instances.add(self)

    def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
        """Used to ratelimit an incoming request from a given host

        Example usage:

            with rate_limiter.ratelimit(origin) as wait_deferred:
                yield wait_deferred
                # Handle request ...

        Args:
            host: Origin of incoming request.

        Returns:
            context manager which returns a deferred.
        """
        per_host_limiter = self.ratelimiters[host]
        return per_host_limiter.ratelimit(host)
|
|
|
|
|
|
class _PerHostRatelimiter:
    """Rate limits the requests of a single host.

    Each request passes through up to three stages when it enters the
    `ratelimit` context manager:

    1. It is rejected with `LimitExceededError` if too many requests are
       already sleeping or queued (more than `reject_limit`).
    2. It is put to sleep for `sleep_sec` seconds if more than `sleep_limit`
       requests have been seen within the last `window_size` milliseconds.
    3. It is queued if `concurrent_requests` requests are already being
       processed, and is released when an earlier request finishes.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock: Used to read the current time, to sleep, and to schedule
                the start of the next queued request.
            config: Supplies `window_size`, `sleep_limit`, `sleep_delay`,
                `reject_limit` and `concurrent`.
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.clock = clock
        self.metrics_name = metrics_name

        self.window_size = config.window_size
        self.sleep_limit = config.sleep_limit
        # `sleep_delay` is divided by 1000 to get the value handed to
        # `clock.sleep`.
        self.sleep_sec = config.sleep_delay / 1000.0
        self.reject_limit = config.reject_limit
        self.concurrent_requests = config.concurrent

        # request_id objects for requests which have been slept
        self.sleeping_requests: Set[object] = set()

        # map from request_id object to Deferred for requests which are ready
        # for processing but have been queued
        self.ready_request_queue: collections.OrderedDict[
            object, defer.Deferred[None]
        ] = collections.OrderedDict()

        # request id objects for requests which are in progress
        self.current_processing: Set[object] = set()

        # times at which we have recently (within the last window_size ms)
        # received requests.
        self.request_times: List[int] = []

    @contextlib.contextmanager
    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
        """Context manager which rate limits a single request.

        Args:
            host: Origin of the request; only used in log messages.

        Yields:
            A deferred which resolves once the request may be processed, or
            fails if the request was rejected.
        """
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by manager.
        # Exceptions will be reraised at the yield.

        # Remembered purely so that log lines can mention the host.
        self.host = host

        # A unique token identifying this request in the bookkeeping sets.
        request_id = object()
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
        try:
            yield ret
        finally:
            # Always runs, so a slot is released even if the request failed.
            self._on_exit(request_id)

    def should_reject(self) -> bool:
        """
        Whether to reject the request if we already have too many queued up
        (either sleeping or in the ready queue).
        """
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        return queue_size > self.reject_limit

    def should_sleep(self) -> bool:
        """
        Whether to sleep the request if we already have too many requests coming
        through within the window.
        """
        return len(self.request_times) > self.sleep_limit

    async def _on_enter_with_tracing(self, request_id: object) -> None:
        """Run `_on_enter` inside a tracing span and, if named, a wait timer."""
        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
        if self.metrics_name:
            # Only time the wait when we have a name to label the metric with.
            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
        with start_active_span("ratelimit wait"), maybe_metrics_cm:
            await self._on_enter(request_id)

    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
        """Decide whether the request is rejected, slept and/or queued.

        Args:
            request_id: Token identifying the request.

        Returns:
            A deferred which resolves once the request may start processing.

        Raises:
            LimitExceededError: if too many requests are already sleeping or
                queued.
        """
        time_now = self.clock.time_msec()

        # remove any entries from request_times which aren't within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now - r < self.window_size
        ]

        # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
        if self.should_reject():
            logger.debug("Ratelimiter(%s): rejecting request", self.host)
            if self.metrics_name:
                rate_limit_reject_counter.labels(self.metrics_name).inc()
            raise LimitExceededError(
                limiter_name="rc_federation",
                # Clamp the divisor so that a `sleep_limit` of 0 yields a
                # `LimitExceededError` rather than a `ZeroDivisionError`.
                retry_after_ms=int(self.window_size / max(self.sleep_limit, 1)),
            )

        self.request_times.append(time_now)

        def queue_request() -> "defer.Deferred[None]":
            # If every processing slot is taken, park the request in the ready
            # queue; `_on_exit` of an earlier request will fire the deferred.
            if len(self.current_processing) >= self.concurrent_requests:
                queue_defer: defer.Deferred[None] = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                logger.info(
                    "Ratelimiter(%s): queueing request (queue now %i items)",
                    self.host,
                    len(self.ready_request_queue),
                )

                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
            self.host,
            id(request_id),
            len(self.request_times),
        )

        if self.should_sleep():
            logger.debug(
                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
                self.host,
                id(request_id),
                self.sleep_sec,
            )
            if self.metrics_name:
                rate_limit_sleep_counter.labels(self.metrics_name).inc()
            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                logger.debug(
                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
                )
                self.sleeping_requests.discard(request_id)
                # Only now does the request compete for a processing slot.
                queue_defer = queue_request()
                return queue_defer

            # `addBoth`: the request moves on to the queue whether the sleep
            # succeeded or failed.
            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r: object) -> object:
            logger.debug(
                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r: object) -> object:
            # XXX: why is this necessary? this is called before we start
            # processing the request so why would the request be in
            # current_processing?
            self.current_processing.discard(request_id)
            return r

        def on_both(r: object) -> object:
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id: object) -> None:
        """Release the request's processing slot and start the next queued one.

        Args:
            request_id: Token identifying the request which has finished.
        """
        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))

        # When requests complete synchronously, we will recursively start the next
        # request in the queue. To avoid stack exhaustion, we defer starting the next
        # request until the next reactor tick.

        def start_next_request() -> None:
            # We only remove the completed request from the list when we're about to
            # start the next one, otherwise we can allow extra requests through.
            self.current_processing.discard(request_id)
            try:
                # start processing the next item on the queue.
                # `last=False` pops the oldest entry, so the queue is FIFO.
                _, deferred = self.ready_request_queue.popitem(last=False)

                with PreserveLoggingContext():
                    deferred.callback(None)
            except KeyError:
                # The queue was empty: nothing is waiting to be started.
                pass

        self.clock.call_later(0.0, start_next_request)
|