Make numeric user_id checker start at @0, and don't ratelimit on checking (#6338)

This commit is contained in:
Andrew Morgan 2019-11-06 17:21:20 +00:00 committed by GitHub
commit a6ebef1bfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 21 deletions

1
changelog.d/6338.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent the server taking a long time to start up when guest registration is enabled.

View File

@ -24,7 +24,6 @@ from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
LimitExceededError,
RegistrationError,
SynapseError,
)
@ -168,6 +167,7 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
yield self.check_registration_ratelimit(address)
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
@ -217,8 +217,13 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
user = None
while not user:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@ -233,10 +238,14 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name,
address=address,
)
# Successfully registered
break
except SynapseError:
# if user id is taken, just generate another
user = None
user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
@ -414,6 +423,29 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
address (str|None): the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
LimitExceededError: If the rate limit has been exceeded.
"""
if not address:
return
time_now = self.clock.time()
self.ratelimiter.ratelimit(
address,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
def register_with_store(
self,
user_id,
@ -446,22 +478,6 @@ class RegistrationHandler(BaseHandler):
Returns:
Deferred
"""
# Don't rate limit for app services
if appservice_id is None and address is not None:
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
address,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,

View File

@ -75,6 +75,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
self.registration_handler.check_registration_ratelimit(content["address"])
await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],

View File

@ -488,14 +488,14 @@ class RegistrationWorkerStore(SQLBaseStore):
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
avoid the case of ID 1000 being pre-allocated and starting at 1001 while
0-999 are available.
"""
def _find_next_generated_user_id(txn):
# We bound between '@1' and '@a' to avoid pulling the entire table
# We bound between '@0' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")