anonymousland-synapse/synapse/util/ratelimitutils.py
Sean Quah 772e8c2385
Fix stack overflow in _PerHostRatelimiter due to synchronous requests (#14812)
When there are many synchronous requests waiting on a
`_PerHostRatelimiter`, each request will be started recursively just
after the previous request has completed. Under the right conditions,
this leads to stack exhaustion.

A common way for requests to become synchronous is when the remote
client disconnects early, because the homeserver is overloaded and slow
to respond.

Avoid stack exhaustion under these conditions by deferring subsequent
requests until the next reactor tick.

Fixes #14480.

Signed-off-by: Sean Quah <seanq@matrix.org>
2023-01-13 00:16:21 +00:00

392 lines
13 KiB
Python

# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import logging
import threading
import typing
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
)
from prometheus_client.core import Counter
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.logging.opentracing import start_active_span
from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock
if typing.TYPE_CHECKING:
from contextlib import _GeneratorContextManager
logger = logging.getLogger(__name__)

# Counters recording how often a ratelimiter had to delay (sleep) or refuse
# (reject) a request, labelled by the ratelimiter's metrics name.
rate_limit_sleep_counter = Counter(
    "synapse_rate_limit_sleep",
    "Number of requests slept by the rate limiter",
    ["rate_limiter_name"],
)
rate_limit_reject_counter = Counter(
    "synapse_rate_limit_reject",
    "Number of requests rejected by the rate limiter",
    ["rate_limiter_name"],
)

# Distribution (in seconds) of how long requests wait before the ratelimiter
# lets them through.
queue_wait_timer = Histogram(
    "synapse_rate_limit_queue_wait_time_seconds",
    "Amount of time spent waiting for the rate limiter to let our request through.",
    ["rate_limiter_name"],
    buckets=(
        0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75,
        1.0, 2.5, 5.0, 10.0, 20.0, "+Inf",
    ),
)

# Every FederationRateLimiter registers itself here so that the affected-host
# gauges can enumerate them.
_rate_limiter_instances: Set["FederationRateLimiter"] = set()
# Guards `_rate_limiter_instances`, which is read from the metrics-collection
# thread while the main thread may be adding to it.
_rate_limiter_instances_lock = threading.Lock()
def _get_counts_from_rate_limiter_instance(
    count_func: Callable[["FederationRateLimiter"], int]
) -> Mapping[Tuple[str, ...], int]:
    """Apply `count_func` to every registered rate limiter that has a metrics name.

    Args:
        count_func: computes a count (e.g. of slept or rejected hosts) for a
            single `FederationRateLimiter`.

    Returns:
        A map from `(metrics_name,)` to the count for that rate limiter. Rate
        limiters without a `metrics_name` are skipped.
    """
    # Copy the registry under the lock so it cannot change underneath us while
    # the Prometheus thread is collecting metrics.
    with _rate_limiter_instances_lock:
        snapshot = list(_rate_limiter_instances)

    # The key is really a 1-tuple, but it is annotated as `Tuple[str, ...]` for
    # compatibility with LaterGauge's annotations.
    counts: Dict[Tuple[str, ...], int] = {
        (limiter.metrics_name,): count_func(limiter)
        for limiter in snapshot
        if limiter.metrics_name
    }
    return counts
# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
#
# These callbacks may run on the Prometheus collection thread, while the main
# thread can add a new host to a `ratelimiters` defaultdict at any moment.
# Iterating the live `.values()` view while calling back into Python code for
# each item can then raise "dictionary changed size during iteration", so we
# snapshot the values with `list()` before iterating.
LaterGauge(
    "synapse_rate_limit_sleep_affected_hosts",
    "Number of hosts that had requests put to sleep",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_sleep()
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)
LaterGauge(
    "synapse_rate_limit_reject_affected_hosts",
    "Number of hosts that had requests rejected",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_reject()
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)
class FederationRateLimiter:
    """Used to rate limit request per-host.

    Each origin host lazily gets its own `_PerHostRatelimiter`; all of them
    share the configuration passed to this object.
    """

    def __init__(
        self,
        reactor: IReactorTime,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            reactor: handed to every per-host ratelimiter.
            clock: handed to every per-host ratelimiter.
            config: the ratelimiting settings applied to every host.
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.metrics_name = metrics_name

        def _limiter_for_new_host() -> "_PerHostRatelimiter":
            # Factory for the defaultdict below: builds the state for a host we
            # have not seen before.
            return _PerHostRatelimiter(
                reactor=reactor,
                clock=clock,
                config=config,
                metrics_name=metrics_name,
            )

        # One entry per host, created on first access. Nothing in this class
        # removes entries.
        per_host: DefaultDict[str, "_PerHostRatelimiter"] = collections.defaultdict(
            _limiter_for_new_host
        )
        self.ratelimiters = per_host

        # Register so the module-level affected-host gauges can see us.
        with _rate_limiter_instances_lock:
            _rate_limiter_instances.add(self)

    def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
        """Used to ratelimit an incoming request from a given host

        Example usage:

            with rate_limiter.ratelimit(origin) as wait_deferred:
                yield wait_deferred
                # Handle request ...

        Args:
            host: Origin of incoming request.

        Returns:
            context manager which returns a deferred.
        """
        host_limiter = self.ratelimiters[host]
        return host_limiter.ratelimit(host)
class _PerHostRatelimiter:
    """Applies sleep, queueing and rejection limits to the requests of one host.

    - Every accepted request is recorded in `request_times`. Once more than
      `sleep_limit` requests have arrived within the last `window_size` ms, new
      requests are delayed by `sleep_delay` ms before being queued.
    - At most `concurrent` requests are processed at once. The rest wait in
      `ready_request_queue` and are started in the order they were queued.
    - If more than `reject_limit` requests are already sleeping or queued, new
      requests fail with `LimitExceededError`.
    """

    def __init__(
        self,
        reactor: IReactorTime,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            reactor: used to schedule starting the next queued request.
            clock: used to read the current time and to sleep requests.
            config: the ratelimiting settings for this host.
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.reactor = reactor
        self.clock = clock
        self.metrics_name = metrics_name

        self.window_size = config.window_size
        self.sleep_limit = config.sleep_limit
        # `sleep_delay` is in milliseconds; `clock.sleep` takes seconds.
        self.sleep_sec = config.sleep_delay / 1000.0
        self.reject_limit = config.reject_limit
        self.concurrent_requests = config.concurrent

        # request_id objects for requests which have been slept
        self.sleeping_requests: Set[object] = set()

        # map from request_id object to Deferred for requests which are ready
        # for processing but have been queued
        self.ready_request_queue: collections.OrderedDict[
            object, defer.Deferred[None]
        ] = collections.OrderedDict()

        # request id objects for requests which are in progress
        self.current_processing: Set[object] = set()

        # times at which we have recently (within the last window_size ms)
        # received requests.
        self.request_times: List[int] = []

    @contextlib.contextmanager
    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
        """Context manager yielding a Deferred that resolves once this request
        may be processed (or fails with `LimitExceededError` if rejected).
        """
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by manager.
        # Exceptions will be reraised at the yield.
        self.host = host

        request_id = object()
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
        try:
            yield ret
        finally:
            # Runs for every request, including ones rejected in `_on_enter`;
            # `_on_exit` must cope with requests that never got a slot.
            self._on_exit(request_id)

    def should_reject(self) -> bool:
        """
        Whether to reject the request if we already have too many queued up
        (either sleeping or in the ready queue).
        """
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        return queue_size > self.reject_limit

    def should_sleep(self) -> bool:
        """
        Whether to sleep the request if we already have too many requests coming
        through within the window.
        """
        return len(self.request_times) > self.sleep_limit

    async def _on_enter_with_tracing(self, request_id: object) -> None:
        """Runs `_on_enter` inside a tracing span, timing the wait if metrics
        are enabled for this ratelimiter."""
        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
        if self.metrics_name:
            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
        with start_active_span("ratelimit wait"), maybe_metrics_cm:
            await self._on_enter(request_id)

    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
        """Admit, sleep, queue or reject a new request.

        Returns:
            A Deferred that resolves once the request has been added to
            `current_processing`.

        Raises:
            LimitExceededError: if too many requests are already sleeping or
                queued.
        """
        time_now = self.clock.time_msec()

        # remove any entries from request_times which aren't within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now - r < self.window_size
        ]

        # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
        if self.should_reject():
            logger.debug("Ratelimiter(%s): rejecting request", self.host)
            if self.metrics_name:
                rate_limit_reject_counter.labels(self.metrics_name).inc()
            raise LimitExceededError(
                retry_after_ms=int(self.window_size / self.sleep_limit)
            )

        self.request_times.append(time_now)

        def queue_request() -> "defer.Deferred[None]":
            # Start immediately if a processing slot is free, otherwise wait in
            # the ready queue until `_on_exit` of an earlier request starts us.
            if len(self.current_processing) >= self.concurrent_requests:
                queue_defer: defer.Deferred[None] = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                logger.info(
                    "Ratelimiter(%s): queueing request (queue now %i items)",
                    self.host,
                    len(self.ready_request_queue),
                )

                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
            self.host,
            id(request_id),
            len(self.request_times),
        )

        if self.should_sleep():
            logger.debug(
                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
                self.host,
                id(request_id),
                self.sleep_sec,
            )
            if self.metrics_name:
                rate_limit_sleep_counter.labels(self.metrics_name).inc()
            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                logger.debug(
                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
                )
                self.sleeping_requests.discard(request_id)
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r: object) -> object:
            logger.debug(
                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r: object) -> object:
            # `ret_defer` failed before `on_start` ran, so this request should
            # not be in `current_processing`; the discard is purely defensive.
            self.current_processing.discard(request_id)
            return r

        def on_both(r: object) -> object:
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id: object) -> None:
        """Called when a request's `ratelimit` context exits.

        If the request held one of the `concurrent` processing slots, the slot
        is released and the oldest queued request (if any) is started on the
        next reactor tick. A request that never held a slot, such as one
        rejected with `LimitExceededError`, must NOT start a queued request.
        Otherwise more than `concurrent` requests could be processed at once.
        """
        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))

        # If this request is still sitting in the ready queue, drop it: its
        # context has already exited, so if it were started later nothing would
        # ever remove it from `current_processing` and the slot would leak.
        self.ready_request_queue.pop(request_id, None)

        if request_id not in self.current_processing:
            # Never started processing (rejected, or exited while waiting), so
            # there is no slot to hand over to the next queued request.
            return

        # When requests complete synchronously, we will recursively start the next
        # request in the queue. To avoid stack exhaustion, we defer starting the next
        # request until the next reactor tick.
        def start_next_request() -> None:
            # We only remove the completed request from the list when we're about to
            # start the next one, otherwise we can allow extra requests through.
            self.current_processing.discard(request_id)
            try:
                # start processing the next item on the queue.
                _, deferred = self.ready_request_queue.popitem(last=False)
            except KeyError:
                # Nothing is waiting.
                return

            with PreserveLoggingContext():
                deferred.callback(None)

        self.reactor.callLater(0.0, start_next_request)