Add rate-limiting on registration (#4735)

* Rate-limiting for registration

* Add unit test for registration rate limiting

* Add config parameters for rate limiting on auth endpoints

* Doc

* Fix doc of rate limiting function

Co-Authored-By: babolivier <contact@brendanabolivier.com>

* Incorporate review

* Fix config parsing

* Fix linting errors

* Set default config for auth rate limiting

* Fix tests

* Add changelog

* Advance reactor instead of mocked clock

* Move parameters to registration specific config and give them more sensible default values

* Remove unused config options

* Don't mock the rate limiter un MAU tests

* Rename _register_with_store into register_with_store

* Make CI happy

* Remove unused import

* Update sample config

* Fix ratelimiting test for py2

* Add non-guest test
This commit is contained in:
Brendan Abolivier 2019-03-05 14:25:33 +00:00 committed by GitHub
parent 3887e0cd80
commit a4c3a361b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 186 additions and 54 deletions

1
changelog.d/4735.feature Normal file
View File

@ -0,0 +1 @@
Add configurable rate limiting to the /register endpoint.

View File

@ -657,6 +657,17 @@ trusted_third_party_id_servers:
# #
autocreate_auto_join_rooms: true autocreate_auto_join_rooms: true
# Number of registration requests a client can send per second.
# Defaults to 1/minute (0.17).
#
#rc_registration_requests_per_second: 0.17
# Number of registration requests a client can send before being
# throttled.
# Defaults to 3.
#
#rc_registration_request_burst_count: 3.0
## Metrics ### ## Metrics ###

View File

@ -23,12 +23,13 @@ class Ratelimiter(object):
def __init__(self): def __init__(self):
self.message_counts = collections.OrderedDict() self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
"""Can the user send a message? """Can the entity (e.g. user or IP address) perform the action?
Args: Args:
user_id: The user sending a message. key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
time_now_s: The time now. time_now_s: The time now.
msg_rate_hz: The long term number of messages a user can send in a rate_hz: The long term number of messages a user can send in a
second. second.
burst_count: How many messages the user can send before being burst_count: How many messages the user can send before being
limited. limited.
@ -41,10 +42,10 @@ class Ratelimiter(object):
""" """
self.prune_message_counts(time_now_s) self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get( message_count, time_start, _ignored = self.message_counts.get(
user_id, (0., time_now_s, None), key, (0., time_now_s, None),
) )
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
sent_count = message_count - time_delta * msg_rate_hz sent_count = message_count - time_delta * rate_hz
if sent_count < 0: if sent_count < 0:
allowed = True allowed = True
time_start = time_now_s time_start = time_now_s
@ -56,13 +57,13 @@ class Ratelimiter(object):
message_count += 1 message_count += 1
if update: if update:
self.message_counts[user_id] = ( self.message_counts[key] = (
message_count, time_start, msg_rate_hz message_count, time_start, rate_hz
) )
if msg_rate_hz > 0: if rate_hz > 0:
time_allowed = ( time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz time_start + (message_count - burst_count + 1) / rate_hz
) )
if time_allowed < time_now_s: if time_allowed < time_now_s:
time_allowed = time_now_s time_allowed = time_now_s
@ -72,12 +73,12 @@ class Ratelimiter(object):
return allowed, time_allowed return allowed, time_allowed
def prune_message_counts(self, time_now_s): def prune_message_counts(self, time_now_s):
for user_id in list(self.message_counts.keys()): for key in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = ( message_count, time_start, rate_hz = (
self.message_counts[user_id] self.message_counts[key]
) )
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0: if message_count - time_delta * rate_hz > 0:
break break
else: else:
del self.message_counts[user_id] del self.message_counts[key]

View File

@ -54,6 +54,13 @@ class RegistrationConfig(Config):
config.get("disable_msisdn_registration", False) config.get("disable_msisdn_registration", False)
) )
self.rc_registration_requests_per_second = config.get(
"rc_registration_requests_per_second", 0.17,
)
self.rc_registration_request_burst_count = config.get(
"rc_registration_request_burst_count", 3,
)
def default_config(self, generate_secrets=False, **kwargs): def default_config(self, generate_secrets=False, **kwargs):
if generate_secrets: if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % ( registration_shared_secret = 'registration_shared_secret: "%s"' % (
@ -140,6 +147,17 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist. # users cannot be auto-joined since they do not exist.
# #
autocreate_auto_join_rooms: true autocreate_auto_join_rooms: true
# Number of registration requests a client can send per second.
# Defaults to 1/minute (0.17).
#
#rc_registration_requests_per_second: 0.17
# Number of registration requests a client can send before being
# throttled.
# Defaults to 3.
#
#rc_registration_request_burst_count: 3.0
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View File

@ -93,9 +93,9 @@ class BaseHandler(object):
messages_per_second = self.hs.config.rc_messages_per_second messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count burst_count = self.hs.config.rc_message_burst_count
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now, user_id, time_now,
msg_rate_hz=messages_per_second, rate_hz=messages_per_second,
burst_count=burst_count, burst_count=burst_count,
update=update, update=update,
) )

View File

@ -24,6 +24,7 @@ from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
InvalidCaptchaError, InvalidCaptchaError,
LimitExceededError,
RegistrationError, RegistrationError,
SynapseError, SynapseError,
) )
@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self.identity_handler = self.hs.get_handlers().identity_handler self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_ratelimiter()
self._next_generated_user_id = None self._next_generated_user_id = None
@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
threepid=None, threepid=None,
user_type=None, user_type=None,
default_display_name=None, default_display_name=None,
address=None,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'. will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the regitration.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
token = None token = None
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
yield self._register_with_store( yield self.register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name, create_profile_with_displayname=default_display_name,
admin=admin, admin=admin,
user_type=user_type, user_type=user_type,
address=address,
) )
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
if default_display_name is None: if default_display_name is None:
default_display_name = localpart default_display_name = localpart
try: try:
yield self._register_with_store( yield self.register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_displayname=default_display_name, create_profile_with_displayname=default_display_name,
address=address,
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service user_id, allowed_appservice=service
) )
yield self._register_with_store( yield self.register_with_store(
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if need_register: if need_register:
yield self._register_with_store( yield self.register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
def _register_with_store(self, user_id, token=None, password_hash=None, def register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False, create_profile_with_displayname=None, admin=False,
user_type=None): user_type=None, address=None):
"""Register user in the datastore. """Register user in the datastore.
Args: Args:
@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
admin (boolean): is an admin user? admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
Returns: Returns:
Deferred Deferred
""" """
# Don't rate limit for app services
if appservice_id is None and address is not None:
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
address, time_now_s=time_now,
rate_hz=self.hs.config.rc_registration_requests_per_second,
burst_count=self.hs.config.rc_registration_request_burst_count,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
if self.hs.config.worker_app: if self.hs.config.worker_app:
return self._register_client( return self._register_client(
user_id=user_id, user_id=user_id,
@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname, create_profile_with_displayname=create_profile_with_displayname,
admin=admin, admin=admin,
user_type=user_type, user_type=user_type,
address=address,
) )
else: else:
return self.store.register( return self.store.register(

View File

@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
def __init__(self, hs): def __init__(self, hs):
super(ReplicationRegisterServlet, self).__init__(hs) super(ReplicationRegisterServlet, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
def _serialize_payload( def _serialize_payload(
user_id, token, password_hash, was_guest, make_guest, appservice_id, user_id, token, password_hash, was_guest, make_guest, appservice_id,
create_profile_with_displayname, admin, user_type, create_profile_with_displayname, admin, user_type, address,
): ):
""" """
Args: Args:
@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin (boolean): is an admin user? admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
""" """
return { return {
"token": token, "token": token,
@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"create_profile_with_displayname": create_profile_with_displayname, "create_profile_with_displayname": create_profile_with_displayname,
"admin": admin, "admin": admin,
"user_type": user_type, "user_type": user_type,
"address": address,
} }
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_request(self, request, user_id): def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
yield self.store.register( yield self.registration_handler.register_with_store(
user_id=user_id, user_id=user_id,
token=content["token"], token=content["token"],
password_hash=content["password_hash"], password_hash=content["password_hash"],
@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
create_profile_with_displayname=content["create_profile_with_displayname"], create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"], admin=content["admin"],
user_type=content["user_type"], user_type=content["user_type"],
address=content["address"]
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -25,7 +25,12 @@ from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.api.errors import (
Codes,
LimitExceededError,
SynapseError,
UnrecognizedRequestError,
)
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_ratelimiter()
self.clock = hs.get_clock()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr, time_now_s=time_now,
rate_hz=self.hs.config.rc_registration_requests_per_second,
burst_count=self.hs.config.rc_registration_request_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
kind = b"user" kind = b"user"
if b"kind" in request.args: if b"kind" in request.args:
kind = request.args[b"kind"][0] kind = request.args[b"kind"][0]
if kind == b"guest": if kind == b"guest":
ret = yield self._do_guest_registration(body) ret = yield self._do_guest_registration(body, address=client_addr)
defer.returnValue(ret) defer.returnValue(ret)
return return
elif kind != b"user": elif kind != b"user":
@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
generate_token=False, generate_token=False,
threepid=threepid, threepid=threepid,
address=client_addr,
) )
# Necessary due to auth checks prior to the threepid being # Necessary due to auth checks prior to the threepid being
# written to the db # written to the db
@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self, params): def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled") raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register( user_id, _ = yield self.registration_handler.register(
generate_token=False, generate_token=False,
make_guest=True make_guest=True,
address=address,
) )
# we don't allow guests to specify their own device_id, because # we don't allow guests to specify their own device_id, because

View File

@ -6,34 +6,34 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):
def test_allowed(self): def test_allowed(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.can_do_action(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1 key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
) )
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.can_do_action(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1 key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
) )
self.assertFalse(allowed) self.assertFalse(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.can_do_action(
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1 key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
) )
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(20., time_allowed) self.assertEquals(20., time_allowed)
def test_pruning(self): def test_pruning(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.can_do_action(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1 key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
) )
self.assertIn("test_id_1", limiter.message_counts) self.assertIn("test_id_1", limiter.message_counts)
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.can_do_action(
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1 key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
) )
self.assertNotIn("test_id_1", limiter.message_counts) self.assertNotIn("test_id_1", limiter.message_counts)

View File

@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.can_do_action.return_value = (True, 0)
self.store = hs.get_datastore() self.store = hs.get_datastore()

View File

@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"blue", "blue",
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
hs.get_ratelimiter().send_message.return_value = (True, 0) hs.get_ratelimiter().can_do_action.return_value = (True, 0)
return hs return hs

View File

@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config.auto_join_rooms = [] config.auto_join_rooms = []
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
config=config, ratelimiter=NonCallableMock(spec_set=["send_message"]) config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View File

@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
"red", "red",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.ratelimiter = self.hs.get_ratelimiter() self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock()) self.hs.get_federation_handler = Mock(return_value=Mock())
@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
# auth as user_id now # auth as user_id now
self.helper.auth_user_id = self.user_id self.helper.auth_user_id = self.user_id
def test_send_message(self): def test_can_do_action(self):
msg_content = b'{"msgtype":"m.text","body":"hello"}' msg_content = b'{"msgtype":"m.text","body":"hello"}'
seq = iter(range(100)) seq = iter(range(100))

View File

@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"red", "red",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
) )
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources["typing"]
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View File

@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled") self.assertEquals(channel.json_body["error"], "Guest access is disabled")
def test_POST_ratelimiting_guest(self):
self.hs.config.rc_registration_request_burst_count = 5
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_POST_ratelimiting(self):
self.hs.config.rc_registration_request_burst_count = 5
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

View File

@ -17,7 +17,7 @@
import json import json
from mock import Mock, NonCallableMock from mock import Mock
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
"red", "red",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()

View File

@ -150,6 +150,8 @@ def default_config(name):
config.admin_contact = None config.admin_contact = None
config.rc_messages_per_second = 10000 config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000 config.rc_message_burst_count = 10000
config.rc_registration_request_burst_count = 3.0
config.rc_registration_requests_per_second = 0.17
config.saml2_enabled = False config.saml2_enabled = False
config.public_baseurl = None config.public_baseurl = None
config.default_identity_server = None config.default_identity_server = None