mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-03 18:50:52 -05:00
Add rate-limiting on registration (#4735)
* Rate-limiting for registration * Add unit test for registration rate limiting * Add config parameters for rate limiting on auth endpoints * Doc * Fix doc of rate limiting function Co-Authored-By: babolivier <contact@brendanabolivier.com> * Incorporate review * Fix config parsing * Fix linting errors * Set default config for auth rate limiting * Fix tests * Add changelog * Advance reactor instead of mocked clock * Move parameters to registration specific config and give them more sensible default values * Remove unused config options * Don't mock the rate limiter un MAU tests * Rename _register_with_store into register_with_store * Make CI happy * Remove unused import * Update sample config * Fix ratelimiting test for py2 * Add non-guest test
This commit is contained in:
parent
3887e0cd80
commit
a4c3a361b7
1
changelog.d/4735.feature
Normal file
1
changelog.d/4735.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add configurable rate limiting to the /register endpoint.
|
@ -657,6 +657,17 @@ trusted_third_party_id_servers:
|
|||||||
#
|
#
|
||||||
autocreate_auto_join_rooms: true
|
autocreate_auto_join_rooms: true
|
||||||
|
|
||||||
|
# Number of registration requests a client can send per second.
|
||||||
|
# Defaults to 1/minute (0.17).
|
||||||
|
#
|
||||||
|
#rc_registration_requests_per_second: 0.17
|
||||||
|
|
||||||
|
# Number of registration requests a client can send before being
|
||||||
|
# throttled.
|
||||||
|
# Defaults to 3.
|
||||||
|
#
|
||||||
|
#rc_registration_request_burst_count: 3.0
|
||||||
|
|
||||||
|
|
||||||
## Metrics ###
|
## Metrics ###
|
||||||
|
|
||||||
|
@ -23,12 +23,13 @@ class Ratelimiter(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.message_counts = collections.OrderedDict()
|
self.message_counts = collections.OrderedDict()
|
||||||
|
|
||||||
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
|
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
|
||||||
"""Can the user send a message?
|
"""Can the entity (e.g. user or IP address) perform the action?
|
||||||
Args:
|
Args:
|
||||||
user_id: The user sending a message.
|
key: The key we should use when rate limiting. Can be a user ID
|
||||||
|
(when sending events), an IP address, etc.
|
||||||
time_now_s: The time now.
|
time_now_s: The time now.
|
||||||
msg_rate_hz: The long term number of messages a user can send in a
|
rate_hz: The long term number of messages a user can send in a
|
||||||
second.
|
second.
|
||||||
burst_count: How many messages the user can send before being
|
burst_count: How many messages the user can send before being
|
||||||
limited.
|
limited.
|
||||||
@ -41,10 +42,10 @@ class Ratelimiter(object):
|
|||||||
"""
|
"""
|
||||||
self.prune_message_counts(time_now_s)
|
self.prune_message_counts(time_now_s)
|
||||||
message_count, time_start, _ignored = self.message_counts.get(
|
message_count, time_start, _ignored = self.message_counts.get(
|
||||||
user_id, (0., time_now_s, None),
|
key, (0., time_now_s, None),
|
||||||
)
|
)
|
||||||
time_delta = time_now_s - time_start
|
time_delta = time_now_s - time_start
|
||||||
sent_count = message_count - time_delta * msg_rate_hz
|
sent_count = message_count - time_delta * rate_hz
|
||||||
if sent_count < 0:
|
if sent_count < 0:
|
||||||
allowed = True
|
allowed = True
|
||||||
time_start = time_now_s
|
time_start = time_now_s
|
||||||
@ -56,13 +57,13 @@ class Ratelimiter(object):
|
|||||||
message_count += 1
|
message_count += 1
|
||||||
|
|
||||||
if update:
|
if update:
|
||||||
self.message_counts[user_id] = (
|
self.message_counts[key] = (
|
||||||
message_count, time_start, msg_rate_hz
|
message_count, time_start, rate_hz
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg_rate_hz > 0:
|
if rate_hz > 0:
|
||||||
time_allowed = (
|
time_allowed = (
|
||||||
time_start + (message_count - burst_count + 1) / msg_rate_hz
|
time_start + (message_count - burst_count + 1) / rate_hz
|
||||||
)
|
)
|
||||||
if time_allowed < time_now_s:
|
if time_allowed < time_now_s:
|
||||||
time_allowed = time_now_s
|
time_allowed = time_now_s
|
||||||
@ -72,12 +73,12 @@ class Ratelimiter(object):
|
|||||||
return allowed, time_allowed
|
return allowed, time_allowed
|
||||||
|
|
||||||
def prune_message_counts(self, time_now_s):
|
def prune_message_counts(self, time_now_s):
|
||||||
for user_id in list(self.message_counts.keys()):
|
for key in list(self.message_counts.keys()):
|
||||||
message_count, time_start, msg_rate_hz = (
|
message_count, time_start, rate_hz = (
|
||||||
self.message_counts[user_id]
|
self.message_counts[key]
|
||||||
)
|
)
|
||||||
time_delta = time_now_s - time_start
|
time_delta = time_now_s - time_start
|
||||||
if message_count - time_delta * msg_rate_hz > 0:
|
if message_count - time_delta * rate_hz > 0:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
del self.message_counts[user_id]
|
del self.message_counts[key]
|
||||||
|
@ -54,6 +54,13 @@ class RegistrationConfig(Config):
|
|||||||
config.get("disable_msisdn_registration", False)
|
config.get("disable_msisdn_registration", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.rc_registration_requests_per_second = config.get(
|
||||||
|
"rc_registration_requests_per_second", 0.17,
|
||||||
|
)
|
||||||
|
self.rc_registration_request_burst_count = config.get(
|
||||||
|
"rc_registration_request_burst_count", 3,
|
||||||
|
)
|
||||||
|
|
||||||
def default_config(self, generate_secrets=False, **kwargs):
|
def default_config(self, generate_secrets=False, **kwargs):
|
||||||
if generate_secrets:
|
if generate_secrets:
|
||||||
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
||||||
@ -140,6 +147,17 @@ class RegistrationConfig(Config):
|
|||||||
# users cannot be auto-joined since they do not exist.
|
# users cannot be auto-joined since they do not exist.
|
||||||
#
|
#
|
||||||
autocreate_auto_join_rooms: true
|
autocreate_auto_join_rooms: true
|
||||||
|
|
||||||
|
# Number of registration requests a client can send per second.
|
||||||
|
# Defaults to 1/minute (0.17).
|
||||||
|
#
|
||||||
|
#rc_registration_requests_per_second: 0.17
|
||||||
|
|
||||||
|
# Number of registration requests a client can send before being
|
||||||
|
# throttled.
|
||||||
|
# Defaults to 3.
|
||||||
|
#
|
||||||
|
#rc_registration_request_burst_count: 3.0
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
@ -93,9 +93,9 @@ class BaseHandler(object):
|
|||||||
messages_per_second = self.hs.config.rc_messages_per_second
|
messages_per_second = self.hs.config.rc_messages_per_second
|
||||||
burst_count = self.hs.config.rc_message_burst_count
|
burst_count = self.hs.config.rc_message_burst_count
|
||||||
|
|
||||||
allowed, time_allowed = self.ratelimiter.send_message(
|
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||||
user_id, time_now,
|
user_id, time_now,
|
||||||
msg_rate_hz=messages_per_second,
|
rate_hz=messages_per_second,
|
||||||
burst_count=burst_count,
|
burst_count=burst_count,
|
||||||
update=update,
|
update=update,
|
||||||
)
|
)
|
||||||
|
@ -24,6 +24,7 @@ from synapse.api.errors import (
|
|||||||
AuthError,
|
AuthError,
|
||||||
Codes,
|
Codes,
|
||||||
InvalidCaptchaError,
|
InvalidCaptchaError,
|
||||||
|
LimitExceededError,
|
||||||
RegistrationError,
|
RegistrationError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
self.identity_handler = self.hs.get_handlers().identity_handler
|
self.identity_handler = self.hs.get_handlers().identity_handler
|
||||||
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
|
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
|
|
||||||
@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
threepid=None,
|
threepid=None,
|
||||||
user_type=None,
|
user_type=None,
|
||||||
default_display_name=None,
|
default_display_name=None,
|
||||||
|
address=None,
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
default_display_name (unicode|None): if set, the new user's displayname
|
default_display_name (unicode|None): if set, the new user's displayname
|
||||||
will be set to this. Defaults to 'localpart'.
|
will be set to this. Defaults to 'localpart'.
|
||||||
|
address (str|None): the IP address used to perform the regitration.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
token = None
|
token = None
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.macaroon_gen.generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
yield self._register_with_store(
|
yield self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
create_profile_with_displayname=default_display_name,
|
create_profile_with_displayname=default_display_name,
|
||||||
admin=admin,
|
admin=admin,
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
|
address=address,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
|
|||||||
if default_display_name is None:
|
if default_display_name is None:
|
||||||
default_display_name = localpart
|
default_display_name = localpart
|
||||||
try:
|
try:
|
||||||
yield self._register_with_store(
|
yield self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
make_guest=make_guest,
|
make_guest=make_guest,
|
||||||
create_profile_with_displayname=default_display_name,
|
create_profile_with_displayname=default_display_name,
|
||||||
|
address=address,
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
# if user id is taken, just generate another
|
# if user id is taken, just generate another
|
||||||
@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
user_id, allowed_appservice=service
|
user_id, allowed_appservice=service
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._register_with_store(
|
yield self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash="",
|
password_hash="",
|
||||||
appservice_id=service_id,
|
appservice_id=service_id,
|
||||||
@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
token = self.macaroon_gen.generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self._register_with_store(
|
yield self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
|
|||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _register_with_store(self, user_id, token=None, password_hash=None,
|
def register_with_store(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_displayname=None, admin=False,
|
create_profile_with_displayname=None, admin=False,
|
||||||
user_type=None):
|
user_type=None, address=None):
|
||||||
"""Register user in the datastore.
|
"""Register user in the datastore.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
|
|||||||
admin (boolean): is an admin user?
|
admin (boolean): is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type (str|None): type of user. One of the values from
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
|
address (str|None): the IP address used to perform the regitration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
|
# Don't rate limit for app services
|
||||||
|
if appservice_id is None and address is not None:
|
||||||
|
time_now = self.clock.time()
|
||||||
|
|
||||||
|
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||||
|
address, time_now_s=time_now,
|
||||||
|
rate_hz=self.hs.config.rc_registration_requests_per_second,
|
||||||
|
burst_count=self.hs.config.rc_registration_request_burst_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise LimitExceededError(
|
||||||
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
|
)
|
||||||
|
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
return self._register_client(
|
return self._register_client(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
create_profile_with_displayname=create_profile_with_displayname,
|
create_profile_with_displayname=create_profile_with_displayname,
|
||||||
admin=admin,
|
admin=admin,
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
|
address=address,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.store.register(
|
return self.store.register(
|
||||||
|
@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReplicationRegisterServlet, self).__init__(hs)
|
super(ReplicationRegisterServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(
|
def _serialize_payload(
|
||||||
user_id, token, password_hash, was_guest, make_guest, appservice_id,
|
user_id, token, password_hash, was_guest, make_guest, appservice_id,
|
||||||
create_profile_with_displayname, admin, user_type,
|
create_profile_with_displayname, admin, user_type, address,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
admin (boolean): is an admin user?
|
admin (boolean): is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type (str|None): type of user. One of the values from
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
|
address (str|None): the IP address used to perform the regitration.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
"create_profile_with_displayname": create_profile_with_displayname,
|
"create_profile_with_displayname": create_profile_with_displayname,
|
||||||
"admin": admin,
|
"admin": admin,
|
||||||
"user_type": user_type,
|
"user_type": user_type,
|
||||||
|
"address": address,
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_request(self, request, user_id):
|
def _handle_request(self, request, user_id):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
yield self.store.register(
|
yield self.registration_handler.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=content["token"],
|
token=content["token"],
|
||||||
password_hash=content["password_hash"],
|
password_hash=content["password_hash"],
|
||||||
@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
create_profile_with_displayname=content["create_profile_with_displayname"],
|
create_profile_with_displayname=content["create_profile_with_displayname"],
|
||||||
admin=content["admin"],
|
admin=content["admin"],
|
||||||
user_type=content["user_type"],
|
user_type=content["user_type"],
|
||||||
|
address=content["address"]
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
@ -25,7 +25,12 @@ from twisted.internet import defer
|
|||||||
import synapse
|
import synapse
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
|
from synapse.api.errors import (
|
||||||
|
Codes,
|
||||||
|
LimitExceededError,
|
||||||
|
SynapseError,
|
||||||
|
UnrecognizedRequestError,
|
||||||
|
)
|
||||||
from synapse.config.server import is_threepid_reserved
|
from synapse.config.server import is_threepid_reserved
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
|
|||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@interactive_auth_handler
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
client_addr = request.getClientIP()
|
||||||
|
|
||||||
|
time_now = self.clock.time()
|
||||||
|
|
||||||
|
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||||
|
client_addr, time_now_s=time_now,
|
||||||
|
rate_hz=self.hs.config.rc_registration_requests_per_second,
|
||||||
|
burst_count=self.hs.config.rc_registration_request_burst_count,
|
||||||
|
update=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise LimitExceededError(
|
||||||
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
|
)
|
||||||
|
|
||||||
kind = b"user"
|
kind = b"user"
|
||||||
if b"kind" in request.args:
|
if b"kind" in request.args:
|
||||||
kind = request.args[b"kind"][0]
|
kind = request.args[b"kind"][0]
|
||||||
|
|
||||||
if kind == b"guest":
|
if kind == b"guest":
|
||||||
ret = yield self._do_guest_registration(body)
|
ret = yield self._do_guest_registration(body, address=client_addr)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != b"user":
|
elif kind != b"user":
|
||||||
@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
threepid=threepid,
|
threepid=threepid,
|
||||||
|
address=client_addr,
|
||||||
)
|
)
|
||||||
# Necessary due to auth checks prior to the threepid being
|
# Necessary due to auth checks prior to the threepid being
|
||||||
# written to the db
|
# written to the db
|
||||||
@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
|
|||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self, params):
|
def _do_guest_registration(self, params, address=None):
|
||||||
if not self.hs.config.allow_guest_access:
|
if not self.hs.config.allow_guest_access:
|
||||||
raise SynapseError(403, "Guest access is disabled")
|
raise SynapseError(403, "Guest access is disabled")
|
||||||
user_id, _ = yield self.registration_handler.register(
|
user_id, _ = yield self.registration_handler.register(
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
make_guest=True
|
make_guest=True,
|
||||||
|
address=address,
|
||||||
)
|
)
|
||||||
|
|
||||||
# we don't allow guests to specify their own device_id, because
|
# we don't allow guests to specify their own device_id, because
|
||||||
|
@ -6,34 +6,34 @@ from tests import unittest
|
|||||||
class TestRatelimiter(unittest.TestCase):
|
class TestRatelimiter(unittest.TestCase):
|
||||||
def test_allowed(self):
|
def test_allowed(self):
|
||||||
limiter = Ratelimiter()
|
limiter = Ratelimiter()
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10., time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10., time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(20., time_allowed)
|
self.assertEquals(20., time_allowed)
|
||||||
|
|
||||||
def test_pruning(self):
|
def test_pruning(self):
|
||||||
limiter = Ratelimiter()
|
limiter = Ratelimiter()
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIn("test_id_1", limiter.message_counts)
|
self.assertIn("test_id_1", limiter.message_counts)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.send_message(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
|
key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertNotIn("test_id_1", limiter.message_counts)
|
self.assertNotIn("test_id_1", limiter.message_counts)
|
||||||
|
@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
federation_client=self.mock_federation,
|
federation_client=self.mock_federation,
|
||||||
federation_server=Mock(),
|
federation_server=Mock(),
|
||||||
federation_registry=self.mock_registry,
|
federation_registry=self.mock_registry,
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"blue",
|
"blue",
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.get_ratelimiter().send_message.return_value = (True, 0)
|
hs.get_ratelimiter().can_do_action.return_value = (True, 0)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
|||||||
config.auto_join_rooms = []
|
config.auto_join_rooms = []
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
|
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
|
@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
|
|||||||
"red",
|
"red",
|
||||||
http_client=None,
|
http_client=None,
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||||
)
|
)
|
||||||
self.ratelimiter = self.hs.get_ratelimiter()
|
self.ratelimiter = self.hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||||
|
|
||||||
self.hs.get_federation_handler = Mock(return_value=Mock())
|
self.hs.get_federation_handler = Mock(return_value=Mock())
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||||||
# auth as user_id now
|
# auth as user_id now
|
||||||
self.helper.auth_user_id = self.user_id
|
self.helper.auth_user_id = self.user_id
|
||||||
|
|
||||||
def test_send_message(self):
|
def test_can_do_action(self):
|
||||||
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
||||||
|
|
||||||
seq = iter(range(100))
|
seq = iter(range(100))
|
||||||
|
@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
"red",
|
"red",
|
||||||
http_client=None,
|
http_client=None,
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.event_source = hs.get_event_sources().sources["typing"]
|
self.event_source = hs.get_event_sources().sources["typing"]
|
||||||
|
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
|
@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
||||||
|
|
||||||
|
def test_POST_ratelimiting_guest(self):
|
||||||
|
self.hs.config.rc_registration_request_burst_count = 5
|
||||||
|
|
||||||
|
for i in range(0, 6):
|
||||||
|
url = self.url + b"?kind=guest"
|
||||||
|
request, channel = self.make_request(b"POST", url, b"{}")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
if i == 5:
|
||||||
|
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||||
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||||
|
else:
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
self.reactor.advance(retry_after_ms / 1000.)
|
||||||
|
|
||||||
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
def test_POST_ratelimiting(self):
|
||||||
|
self.hs.config.rc_registration_request_burst_count = 5
|
||||||
|
|
||||||
|
for i in range(0, 6):
|
||||||
|
params = {
|
||||||
|
"username": "kermit" + str(i),
|
||||||
|
"password": "monkey",
|
||||||
|
"device_id": "frogfone",
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
}
|
||||||
|
request_data = json.dumps(params)
|
||||||
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
if i == 5:
|
||||||
|
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||||
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||||
|
else:
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
self.reactor.advance(retry_after_ms / 1000.)
|
||||||
|
|
||||||
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
|||||||
"red",
|
"red",
|
||||||
http_client=None,
|
http_client=None,
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
|
@ -150,6 +150,8 @@ def default_config(name):
|
|||||||
config.admin_contact = None
|
config.admin_contact = None
|
||||||
config.rc_messages_per_second = 10000
|
config.rc_messages_per_second = 10000
|
||||||
config.rc_message_burst_count = 10000
|
config.rc_message_burst_count = 10000
|
||||||
|
config.rc_registration_request_burst_count = 3.0
|
||||||
|
config.rc_registration_requests_per_second = 0.17
|
||||||
config.saml2_enabled = False
|
config.saml2_enabled = False
|
||||||
config.public_baseurl = None
|
config.public_baseurl = None
|
||||||
config.default_identity_server = None
|
config.default_identity_server = None
|
||||||
|
Loading…
Reference in New Issue
Block a user