Merge pull request #650 from matrix-org/dbkr/register_idempotent_with_username

Make registration idempotent, part 2
This commit is contained in:
David Baker 2016-03-17 14:34:08 +00:00
commit 384ee6eafb
3 changed files with 38 additions and 6 deletions

View File

@ -160,6 +160,20 @@ class AuthHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
return sid
def set_session_data(self, session_id, key, value): def set_session_data(self, session_id, key, value):
""" """
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this

View File

@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None): def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
yield run_on_reactor() yield run_on_reactor()
if urllib.quote(localpart.encode('utf-8')) != localpart: if urllib.quote(localpart.encode('utf-8')) != localpart:
@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
if assigned_user_id:
if user_id == assigned_user_id:
return
else:
raise SynapseError(
400,
"A different user ID has already been registered for this session",
)
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)

View File

@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( yield self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
) )
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -147,10 +159,6 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((401, result)) defer.returnValue((401, result))
return return
# have we already registered a user for this session
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if registered_user_id is not None: if registered_user_id is not None:
logger.info( logger.info(
"Already registered user ID %r for this session", "Already registered user ID %r for this session",