Performance improvements and refactor of Ratelimiter (#7595)

While working on https://github.com/matrix-org/synapse/issues/5665 I found myself digging into the `Ratelimiter` class and seeing that it was both:

* Rather undocumented, and
* causing a *lot* of config checks

This PR attempts to refactor and comment the `Ratelimiter` class, as well as encourage config file accesses to only be done at instantiation. 

Best to be reviewed commit-by-commit.
This commit is contained in:
Andrew Morgan 2020-06-05 10:47:20 +01:00 committed by GitHub
parent c389bfb6ea
commit f4e6495b5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 325 additions and 233 deletions

1
changelog.d/7595.misc Normal file
View File

@ -0,0 +1 @@
Refactor `Ratelimiter` to limit the amount of expensive config value accesses.

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,75 +17,157 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.util import Clock
class Ratelimiter(object): class Ratelimiter(object):
""" """
Ratelimit message sending by user. Ratelimit actions marked by arbitrary keys.
Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second.
burst_count: How many actions that can be performed before being limited.
""" """
def __init__(self): def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
self.message_counts = ( self.clock = clock
OrderedDict() self.rate_hz = rate_hz
) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] self.burst_count = burst_count
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): # A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
# to a tuple representing:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
def can_do_action(
self,
key: Any,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action? """Can the entity (e.g. user or IP address) perform the action?
Args: Args:
key: The key we should use when rate limiting. Can be a user ID key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc. (when sending events), an IP address, etc.
time_now_s: The time now. rate_hz: The long term number of actions that can be performed in a second.
rate_hz: The long term number of messages a user can send in a Overrides the value set during instantiation if set.
second. burst_count: How many actions that can be performed before being limited.
burst_count: How many messages the user can send before being Overrides the value set during instantiation if set.
limited. update: Whether to count this check as performing the action
update (bool): Whether to update the message rates or not. This is _time_now_s: The current time. Optional, defaults to the current time according
useful to check if a message would be allowed to be sent before to self.clock. Only used by tests.
its ready to be actually sent.
Returns: Returns:
A pair of a bool indicating if they can send a message now and a A tuple containing:
time in seconds of when they can next send a message. * A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
""" """
self.prune_message_counts(time_now_s) # Override default values if set
message_count, time_start, _ignored = self.message_counts.get( time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
key, (0.0, time_now_s, None) rate_hz = rate_hz if rate_hz is not None else self.rate_hz
) burst_count = burst_count if burst_count is not None else self.burst_count
# Remove any expired entries
self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key
action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
# Check whether performing another action is allowed
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
sent_count = message_count - time_delta * rate_hz performed_count = action_count - time_delta * rate_hz
if sent_count < 0: if performed_count < 0:
# Allow, reset back to count 1
allowed = True allowed = True
time_start = time_now_s time_start = time_now_s
message_count = 1.0 action_count = 1.0
elif sent_count > burst_count - 1.0: elif performed_count > burst_count - 1.0:
# Deny, we have exceeded our burst count
allowed = False allowed = False
else: else:
# We haven't reached our limit yet
allowed = True allowed = True
message_count += 1 action_count += 1.0
if update: if update:
self.message_counts[key] = (message_count, time_start, rate_hz) self.actions[key] = (action_count, time_start, rate_hz)
if rate_hz > 0: if rate_hz > 0:
time_allowed = time_start + (message_count - burst_count + 1) / rate_hz # Find out when the count of existing actions expires
time_allowed = time_start + (action_count - burst_count + 1) / rate_hz
# Don't give back a time in the past
if time_allowed < time_now_s: if time_allowed < time_now_s:
time_allowed = time_now_s time_allowed = time_now_s
else: else:
# XXX: Why is this -1? This seems to only be used in
# self.ratelimit. I guess so that clients get a time in the past and don't
# feel afraid to try again immediately
time_allowed = -1 time_allowed = -1
return allowed, time_allowed return allowed, time_allowed
def prune_message_counts(self, time_now_s): def _prune_message_counts(self, time_now_s: int):
for key in list(self.message_counts.keys()): """Remove message count entries that have not exceeded their defined
message_count, time_start, rate_hz = self.message_counts[key] rate_hz limit
time_delta = time_now_s - time_start
if message_count - time_delta * rate_hz > 0: Args:
break time_now_s: The current time
else: """
del self.message_counts[key] # We create a copy of the key list here as the dictionary is modified during
# the loop
for key in list(self.actions.keys()):
action_count, time_start, rate_hz = self.actions[key]
# Rate limit = "seconds since we started limiting this action" * rate_hz
# If this limit has not been exceeded, wipe our record of this action
time_delta = time_now_s - time_start
if action_count - time_delta * rate_hz > 0:
continue
else:
del self.actions[key]
def ratelimit(
self,
key: Any,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
Args:
key: An arbitrary key used to classify an action
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Raises:
LimitExceededError: If an action could not be performed, along with the time in
milliseconds until the action can be performed again
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
allowed, time_allowed = self.can_do_action( allowed, time_allowed = self.can_do_action(
key, time_now_s, rate_hz, burst_count, update key,
rate_hz=rate_hz,
burst_count=burst_count,
update=update,
_time_now_s=time_now_s,
) )
if not allowed: if not allowed:

View File

@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
from ._base import Config from ._base import Config
class RateLimitConfig(object): class RateLimitConfig(object):
def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): def __init__(
self,
config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"]) self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"]) self.burst_count = config.get("burst_count", defaults["burst_count"])

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID from synapse.types import UserID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,11 +44,26 @@ class BaseHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@ -70,7 +85,6 @@ class BaseHandler(object):
Raises: Raises:
LimitExceededError if the request should be ratelimited LimitExceededError if the request should be ratelimited
""" """
time_now = self.clock.time()
user_id = requester.user.to_string() user_id = requester.user.to_string()
# The AS user itself is never rate limited. # The AS user itself is never rate limited.
@ -83,48 +97,32 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited(): if requester.app_service and not requester.app_service.is_rate_limited():
return return
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB. # Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id) override = yield self.store.get_ratelimit_for_user(user_id)
if override: if override:
# If overriden with a null Hz then ratelimiting has been entirely # If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user # disabled for the user
if not override.messages_per_second: if not override.messages_per_second:
return return
messages_per_second = override.messages_per_second messages_per_second = override.messages_per_second
burst_count = override.burst_count burst_count = override.burst_count
else:
# We default to different values if this is an admin redaction and
# the config is set
if is_admin_redaction and self.hs.config.rc_admin_redaction:
messages_per_second = self.hs.config.rc_admin_redaction.per_second
burst_count = self.hs.config.rc_admin_redaction.burst_count
else:
messages_per_second = self.hs.config.rc_message.per_second
burst_count = self.hs.config.rc_message.burst_count
if is_admin_redaction and self.hs.config.rc_admin_redaction: if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions we use a separate # If we have separate config for admin redactions, use a separate
# ratelimiter # ratelimiter as to not have user_ids clash
allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
user_id,
time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
else: else:
allowed, time_allowed = self.ratelimiter.can_do_action( # Override rate and burst count per-user
self.request_ratelimiter.ratelimit(
user_id, user_id,
time_now,
rate_hz=messages_per_second, rate_hz=messages_per_second,
burst_count=burst_count, burst_count=burst_count,
update=update, update=update,
) )
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
async def maybe_kick_guest_users(self, event, context=None): async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.

View File

@ -108,7 +108,11 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter() self._failed_uia_attempts_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock() self._clock = self.hs.get_clock()
@ -196,13 +200,7 @@ class AuthHandler(BaseHandler):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts # Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit( self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
# build a list of supported flows # build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types] flows = [[login_type] for login_type in self._supported_ui_auth_types]
@ -212,14 +210,8 @@ class AuthHandler(BaseHandler):
flows, request, request_body, clientip, description flows, request, request_body, clientip, description
) )
except LoginError: except LoginError:
# Update the ratelimite to say we failed (`can_do_action` doesn't raise). # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action( self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise raise
# find the completed login type # find the completed login type

View File

@ -362,7 +362,6 @@ class EventCreationHandler(object):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.config = hs.config self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases self.require_membership_for_aliases = hs.config.require_membership_for_aliases

View File

@ -425,14 +425,7 @@ class RegistrationHandler(BaseHandler):
if not address: if not address:
return return
time_now = self.clock.time() self.ratelimiter.ratelimit(address)
self.ratelimiter.ratelimit(
address,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
def register_with_store( def register_with_store(
self, self,

View File

@ -87,11 +87,22 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter() self._address_ratelimiter = Ratelimiter(
self._account_ratelimiter = Ratelimiter() clock=hs.get_clock(),
self._failed_attempts_ratelimiter = Ratelimiter() rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
self._failed_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -124,13 +135,7 @@ class LoginRestServlet(RestServlet):
return 200, {} return 200, {}
async def on_POST(self, request): async def on_POST(self, request):
self._address_ratelimiter.ratelimit( self._address_ratelimiter.ratelimit(request.getClientIP())
request.getClientIP(),
time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
)
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
try: try:
@ -198,13 +203,7 @@ class LoginRestServlet(RestServlet):
# We also apply account rate limiting using the 3PID as a key, as # We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID. # otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit( self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
# Check for login providers that support 3pid login types # Check for login providers that support 3pid login types
( (
@ -238,13 +237,7 @@ class LoginRestServlet(RestServlet):
# If it returned None but the 3PID was bound then we won't hit # If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit # this code path, which is fine as then the per-user ratelimit
# will kick in below. # will kick in below.
self._failed_attempts_ratelimiter.can_do_action( self._failed_attempts_ratelimiter.can_do_action((medium, address))
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id} identifier = {"type": "m.id.user", "user": user_id}
@ -263,11 +256,7 @@ class LoginRestServlet(RestServlet):
# Check if we've hit the failed ratelimit (but don't update it) # Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit( self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), qualified_user_id.lower(), update=False
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
) )
try: try:
@ -279,13 +268,7 @@ class LoginRestServlet(RestServlet):
# limiter. Using `can_do_action` avoids us raising a ratelimit # limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting # exception and masking the LoginError. The actual ratelimiting
# should have happened above. # should have happened above.
self._failed_attempts_ratelimiter.can_do_action( self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise raise
result = await self._complete_login( result = await self._complete_login(
@ -318,13 +301,7 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in # Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't # too often. This happens here rather than before as we don't
# necessarily know the user before now. # necessarily know the user before now.
self._account_ratelimiter.ratelimit( self._account_ratelimiter.ratelimit(user_id.lower())
user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
)
if create_non_existent_users: if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id) canonical_uid = await self.auth_handler.check_user_exists(user_id)

View File

@ -26,7 +26,6 @@ import synapse.types
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
LimitExceededError,
SynapseError, SynapseError,
ThreepidValidationError, ThreepidValidationError,
UnrecognizedRequestError, UnrecognizedRequestError,
@ -396,20 +395,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP() client_addr = request.getClientIP()
time_now = self.clock.time() self.ratelimiter.ratelimit(client_addr, update=False)
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
kind = b"user" kind = b"user"
if b"kind" in request.args: if b"kind" in request.args:

View File

@ -242,9 +242,12 @@ class HomeServer(object):
self.clock = Clock(reactor) self.clock = Clock(reactor)
self.distributor = Distributor() self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter(
self.registration_ratelimiter = Ratelimiter() clock=self.clock,
rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count,
)
self.datastores = None self.datastores = None
@ -314,15 +317,9 @@ class HomeServer(object):
def get_distributor(self): def get_distributor(self):
return self.distributor return self.distributor
def get_ratelimiter(self): def get_registration_ratelimiter(self) -> Ratelimiter:
return self.ratelimiter
def get_registration_ratelimiter(self):
return self.registration_ratelimiter return self.registration_ratelimiter
def get_admin_redaction_ratelimiter(self):
return self.admin_redaction_ratelimiter
def build_federation_client(self): def build_federation_client(self):
return FederationClient(self) return FederationClient(self)

View File

@ -43,7 +43,7 @@ class FederationRateLimiter(object):
self.ratelimiters = collections.defaultdict(new_limiter) self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host): def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host """Used to ratelimit an incoming request from a given host
Example usage: Example usage:

View File

@ -1,39 +1,97 @@
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from tests import unittest from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):
def test_allowed(self): def test_allowed_via_can_do_action(self):
limiter = Ratelimiter() limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_do_action( allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed) self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action( allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
)
self.assertFalse(allowed) self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed) self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action( allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed) self.assertEquals(20.0, time_allowed)
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
# Shouldn't raise
limiter.ratelimit(key="test_id", _time_now_s=0)
# Should raise
with self.assertRaises(LimitExceededError) as context:
limiter.ratelimit(key="test_id", _time_now_s=5)
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
limiter.ratelimit(key="test_id", _time_now_s=10)
def test_allowed_via_can_do_action_and_overriding_parameters(self):
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
# First attempt should be allowed
allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
allowed, time_allowed = limiter.can_do_action(
("test_id",), _time_now_s=1, rate_hz=10.0
)
self.assertTrue(allowed)
self.assertEqual(1.1, time_allowed)
# Similarly if we allow a burst of 10 actions
allowed, time_allowed = limiter.can_do_action(
("test_id",), _time_now_s=1, burst_count=10
)
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
def test_allowed_via_ratelimit_and_overriding_parameters(self):
"""Test that we can override options of the ratelimit method that would otherwise
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
# First attempt should be allowed
limiter.ratelimit(key=("test_id",), _time_now_s=0)
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
limiter.ratelimit(key=("test_id",), _time_now_s=1)
self.assertEqual(context.exception.retry_after_ms, 9000)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
# Similarly if we allow a burst of 10 actions
limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
def test_pruning(self): def test_pruning(self):
limiter = Ratelimiter() limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_do_action( limiter.can_do_action(key="test_id_1", _time_now_s=0)
key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
)
self.assertIn("test_id_1", limiter.message_counts) self.assertIn("test_id_1", limiter.actions)
allowed, time_allowed = limiter.can_do_action( limiter.can_do_action(key="test_id_2", _time_now_s=10)
key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
)
self.assertNotIn("test_id_1", limiter.message_counts) self.assertNotIn("test_id_1", limiter.actions)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from mock import Mock, NonCallableMock from mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test") self.frank = UserID.from_string("@1234ABCD:test")

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock, NonCallableMock from mock import Mock
from tests.replication._base import BaseStreamTestCase from tests.replication._base import BaseStreamTestCase
@ -21,12 +21,7 @@ from tests.replication._base import BaseStreamTestCase
class BaseSlavedStoreTestCase(BaseStreamTestCase): class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(federation_client=Mock())
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
hs.get_ratelimiter().can_do_action.return_value = (True, 0)
return hs return hs

View File

@ -15,7 +15,7 @@
""" Tests REST events for /events paths.""" """ Tests REST events for /events paths."""
from mock import Mock, NonCallableMock from mock import Mock
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1 import events, login, room
@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config["enable_registration"] = True config["enable_registration"] = True
config["auto_join_rooms"] = [] config["auto_join_rooms"] = []
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(config=config)
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View File

@ -29,7 +29,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
@ -38,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
@override_config(
{
"rc_login": {
"address": {"per_second": 0.17, "burst_count": 5},
# Prevent the account login ratelimiter from raising first
#
# This is normally covered by the default test homeserver config
# which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well
"account": {"per_second": 10000, "burst_count": 10000},
}
}
)
def test_POST_ratelimiting_per_address(self): def test_POST_ratelimiting_per_address(self):
self.hs.config.rc_login_address.burst_count = 5
self.hs.config.rc_login_address.per_second = 0.17
# Create different users so we're sure not to be bothered by the per-user # Create different users so we're sure not to be bothered by the per-user
# ratelimiter. # ratelimiter.
for i in range(0, 6): for i in range(0, 6):
@ -80,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
"account": {"per_second": 0.17, "burst_count": 5},
# Prevent the address login ratelimiter from raising first
#
# This is normally covered by the default test homeserver config
# which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000},
}
}
)
def test_POST_ratelimiting_per_account(self): def test_POST_ratelimiting_per_account(self):
self.hs.config.rc_login_account.burst_count = 5
self.hs.config.rc_login_account.per_second = 0.17
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):
@ -119,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
# Prevent the address login ratelimiter from raising first
#
# This is normally covered by the default test homeserver config
# which sets these values to 10000, but as we're overriding the entire
# rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 0.17, "burst_count": 5},
}
}
)
def test_POST_ratelimiting_per_account_failed_attempts(self): def test_POST_ratelimiting_per_account_failed_attempts(self):
self.hs.config.rc_login_failed_attempts.burst_count = 5
self.hs.config.rc_login_failed_attempts.per_second = 0.17
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):

View File

@ -20,7 +20,7 @@
import json import json
from mock import Mock, NonCallableMock from mock import Mock
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver( self.hs = self.setup_test_homeserver(
"red", "red", http_client=None, federation_client=Mock(),
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock()) self.hs.get_federation_handler = Mock(return_value=Mock())

View File

@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
from mock import Mock, NonCallableMock from mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red", http_client=None, federation_client=Mock(),
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources["typing"]
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):

View File

@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest from tests import unittest
from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase): class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@ -146,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled") self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self): def test_POST_ratelimiting_guest(self):
self.hs.config.rc_registration.burst_count = 5
self.hs.config.rc_registration.per_second = 0.17
for i in range(0, 6): for i in range(0, 6):
url = self.url + b"?kind=guest" url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}") request, channel = self.make_request(b"POST", url, b"{}")
@ -168,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self): def test_POST_ratelimiting(self):
self.hs.config.rc_registration.burst_count = 5
self.hs.config.rc_registration.per_second = 0.17
for i in range(0, 6): for i in range(0, 6):
params = { params = {
"username": "kermit" + str(i), "username": "kermit" + str(i),