Make registration idempotent, part 2: be idempotent if the client specifies a username.

This commit is contained in:
David Baker 2016-03-16 19:36:57 +00:00
parent 48b2e853a8
commit a7daa5ae13
3 changed files with 42 additions and 6 deletions

View File

@ -160,6 +160,20 @@ class AuthHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
return sid
def set_session_data(self, session_id, key, value): def set_session_data(self, session_id, key, value):
""" """
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this

View File

@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None self._next_generated_user_id = None
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None): def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
yield run_on_reactor() yield run_on_reactor()
if urllib.quote(localpart.encode('utf-8')) != localpart: if urllib.quote(localpart.encode('utf-8')) != localpart:
@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
if assigned_user_id:
if user_id == assigned_user_id:
return
else:
raise SynapseError(
400,
"A different user ID has already been registered for this session",
)
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -122,10 +123,25 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
session_id = self.auth_handler.get_session_id(body)
logger.error("session id: %r", session_id)
registered_user_id = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
logger.error("already regged: %r", registered_user_id)
logger.error("check: %r", desired_username)
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( yield self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
) )
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -147,10 +163,6 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((401, result)) defer.returnValue((401, result))
return return
# have we already registered a user for this session
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if registered_user_id is not None: if registered_user_id is not None:
logger.info( logger.info(
"Already registered user ID %r for this session", "Already registered user ID %r for this session",