diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index 6800ac46c..7028ca694 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -21,7 +21,7 @@ support HTTPS), however individual pairings of servers may decide to communicate over a different (albeit still reliable) protocol. """ -from .server import TransportLayerServer +from .server import TransportLayerServer, FederationRateLimiter from .client import TransportLayerClient @@ -55,8 +55,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient): send requests """ self.keyring = homeserver.get_keyring() + self.clock = homeserver.get_clock() self.server_name = server_name self.server = server self.client = client self.request_handler = None self.received_handler = None + + self.ratelimiter = FederationRateLimiter( + self.clock, + window_size=10000, + sleep_limit=10, + sleep_msec=500, + reject_limit=50, + concurrent_requests=3, + ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2ffb37aa1..a9e625f12 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -16,9 +16,11 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, SynapseError, LimitExceededError +from synapse.util.async import sleep from synapse.util.logutils import log_function +import collections import logging import simplejson as json import re @@ -27,6 +29,163 @@ import re logger = logging.getLogger(__name__) +class FederationRateLimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.ratelimiters = {} + + def ratelimit(self, host): + return self.ratelimiters.setdefault( + host, + PerHostRatelimiter( + clock=self.clock, + window_size=self.window_size, + sleep_limit=self.sleep_limit, + sleep_msec=self.sleep_msec, + reject_limit=self.reject_limit, + concurrent_requests=self.concurrent_requests, + ) + ).ratelimit() + + +class PerHostRatelimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.sleeping_requests = set() + self.ready_request_queue = collections.OrderedDict() + self.current_processing = set() + self.request_times = [] + + def is_empty(self): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + return not ( + self.ready_request_queue + or self.sleeping_requests + or self.current_processing + or self.request_times + ) + + def ratelimit(self): + request_id = object() + + def on_enter(): + return self._on_enter(request_id) + + def on_exit(exc_type, exc_val, exc_tb): + return self._on_exit(request_id) + + return ContextManagerFunction(on_enter, on_exit) + + def _on_enter(self, request_id): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) + if queue_size > self.reject_limit: + raise LimitExceededError( + retry_after_ms=int( + self.window_size / self.sleep_limit + ), + ) + + self.request_times.append(time_now) + + def queue_request(): + if len(self.current_processing) > self.concurrent_requests: + logger.debug("Ratelimit [%s]: Queue req", id(request_id)) + queue_defer = defer.Deferred() + self.ready_request_queue[request_id] = queue_defer + return queue_defer + else: + return defer.succeed(None) + + logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times)) + logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times)) + + if len(self.request_times) > self.sleep_limit: + logger.debug("Ratelimit [%s]: sleeping req", id(request_id)) + ret_defer = sleep(self.sleep_msec/1000.0) + + self.sleeping_requests.add(request_id) + + def on_wait_finished(_): + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) + self.sleeping_requests.discard(request_id) + queue_defer = queue_request() + return queue_defer + + ret_defer.addBoth(on_wait_finished) + else: + ret_defer = queue_request() + + def on_start(r): + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) + self.current_processing.add(request_id) + return r + + def on_err(r): + self.current_processing.discard(request_id) + return r + + def on_both(r): + # Ensure that we've properly cleaned up. + self.sleeping_requests.discard(request_id) + self.ready_request_queue.pop(request_id, None) + return r + + ret_defer.addCallbacks(on_start, on_err) + ret_defer.addBoth(on_both) + return ret_defer + + def _on_exit(self, request_id): + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) + self.current_processing.discard(request_id) + try: + request_id, deferred = self.ready_request_queue.popitem() + self.current_processing.add(request_id) + deferred.callback(None) + except KeyError: + pass + + +class ContextManagerFunction(object): + def __init__(self, on_enter, on_exit): + self.on_enter = on_enter + self.on_exit = on_exit + + def __enter__(self): + if self.on_enter: + return self.on_enter() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.on_exit: + return self.on_exit(exc_type, exc_val, exc_tb) + + class TransportLayerServer(object): """Handles incoming federation HTTP requests""" @@ -98,15 +257,23 @@ class TransportLayerServer(object): def new_handler(request, *args, **kwargs): try: (origin, content) = yield self._authenticate_request(request) - response = yield handler( - origin, content, request.args, *args, **kwargs - ) + with self.ratelimiter.ratelimit(origin) as d: + yield d + response = yield handler( + origin, content, request.args, *args, **kwargs + ) except: logger.exception("_authenticate_request failed") raise defer.returnValue(response) return new_handler + def rate_limit_origin(self, handler): + def new_handler(origin, *args, **kwargs): + response = yield handler(origin, *args, **kwargs) + defer.returnValue(response) + return new_handler() + @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data.