mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-24 20:11:07 -05:00
Merge pull request #649 from matrix-org/dbkr/idempotent_registration
Make registration idempotent
This commit is contained in:
commit
48b2e853a8
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
|
|||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
clientip (str): The IP address of the client.
|
clientip (str): The IP address of the client.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (authed, dict, dict) where authed is true if the client
|
A tuple of (authed, dict, dict, session_id) where authed is true if
|
||||||
has successfully completed an auth flow. If it is true, the first
|
the client has successfully completed an auth flow. If it is true
|
||||||
dict contains the authenticated credentials of each stage.
|
the first dict contains the authenticated credentials of each stage.
|
||||||
|
|
||||||
If authed is false, the first dictionary is the server response to
|
If authed is false, the first dictionary is the server response to
|
||||||
the login request and should be passed back to the client.
|
the login request and should be passed back to the client.
|
||||||
|
|
||||||
In either case, the second dict contains the parameters for this
|
In either case, the second dict contains the parameters for this
|
||||||
request (which may have been given only in a previous call).
|
request (which may have been given only in a previous call).
|
||||||
|
|
||||||
|
session_id is the ID of this session, either passed in by the client
|
||||||
|
or assigned by the call to check_auth
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authdict = None
|
authdict = None
|
||||||
@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(False, self._auth_dict_for_flows(flows, session), clientdict)
|
(
|
||||||
|
False, self._auth_dict_for_flows(flows, session),
|
||||||
|
clientdict, session['id']
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'creds' not in session:
|
if 'creds' not in session:
|
||||||
@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
|
|||||||
for f in flows:
|
for f in flows:
|
||||||
if len(set(f) - set(creds.keys())) == 0:
|
if len(set(f) - set(creds.keys())) == 0:
|
||||||
logger.info("Auth completed with creds: %r", creds)
|
logger.info("Auth completed with creds: %r", creds)
|
||||||
self._remove_session(session)
|
defer.returnValue((True, creds, clientdict, session['id']))
|
||||||
defer.returnValue((True, creds, clientdict))
|
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
ret['completed'] = creds.keys()
|
ret['completed'] = creds.keys()
|
||||||
defer.returnValue((False, ret, clientdict))
|
defer.returnValue((False, ret, clientdict, session['id']))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||||
@ -154,6 +160,29 @@ class AuthHandler(BaseHandler):
|
|||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
def set_session_data(self, session_id, key, value):
|
||||||
|
"""
|
||||||
|
Store a key-value pair into the sessions data associated with this
|
||||||
|
request. This data is stored server-side and cannot be modified by
|
||||||
|
the client.
|
||||||
|
:param session_id: (string) The ID of this session as returned from check_auth
|
||||||
|
:param key: (string) The key to store the data under
|
||||||
|
:param value: (any) The data to store
|
||||||
|
"""
|
||||||
|
sess = self._get_session_info(session_id)
|
||||||
|
sess.setdefault('serverdict', {})[key] = value
|
||||||
|
self._save_session(sess)
|
||||||
|
|
||||||
|
def get_session_data(self, session_id, key, default=None):
|
||||||
|
"""
|
||||||
|
Retrieve data stored with set_session_data
|
||||||
|
:param session_id: (string) The ID of this session as returned from check_auth
|
||||||
|
:param key: (string) The key to store the data under
|
||||||
|
:param default: (any) Value to return if the key has not been set
|
||||||
|
"""
|
||||||
|
sess = self._get_session_info(session_id)
|
||||||
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_password_auth(self, authdict, _):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
@ -455,11 +484,18 @@ class AuthHandler(BaseHandler):
|
|||||||
def _save_session(self, session):
|
def _save_session(self, session):
|
||||||
# TODO: Persistent storage
|
# TODO: Persistent storage
|
||||||
logger.debug("Saving session %s", session)
|
logger.debug("Saving session %s", session)
|
||||||
|
session["last_used"] = self.hs.get_clock().time_msec()
|
||||||
self.sessions[session["id"]] = session
|
self.sessions[session["id"]] = session
|
||||||
|
self._prune_sessions()
|
||||||
|
|
||||||
def _remove_session(self, session):
|
def _prune_sessions(self):
|
||||||
logger.debug("Removing session %s", session)
|
for sid, sess in self.sessions.items():
|
||||||
del self.sessions[session["id"]]
|
last_used = 0
|
||||||
|
if 'last_used' in sess:
|
||||||
|
last_used = sess['last_used']
|
||||||
|
now = self.hs.get_clock().time_msec()
|
||||||
|
if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
|
||||||
|
del self.sessions[sid]
|
||||||
|
|
||||||
def hash(self, password):
|
def hash(self, password):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
@ -43,7 +43,7 @@ class PasswordRestServlet(RestServlet):
|
|||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
authed, result, params = yield self.auth_handler.check_auth([
|
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
@ -139,7 +139,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
]
|
]
|
||||||
|
|
||||||
authed, result, params = yield self.auth_handler.check_auth(
|
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,6 +147,26 @@ class RegisterRestServlet(RestServlet):
|
|||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# have we already registered a user for this session
|
||||||
|
registered_user_id = self.auth_handler.get_session_data(
|
||||||
|
session_id, "registered_user_id", None
|
||||||
|
)
|
||||||
|
if registered_user_id is not None:
|
||||||
|
logger.info(
|
||||||
|
"Already registered user ID %r for this session",
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
|
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
||||||
|
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"user_id": registered_user_id,
|
||||||
|
"access_token": access_token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}))
|
||||||
|
|
||||||
# NB: This may be from the auth handler and NOT from the POST
|
# NB: This may be from the auth handler and NOT from the POST
|
||||||
if 'password' not in params:
|
if 'password' not in params:
|
||||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||||
@ -161,6 +181,12 @@ class RegisterRestServlet(RestServlet):
|
|||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# remember that we've now registered that user account, and with what
|
||||||
|
# user ID (since the user may not have specified)
|
||||||
|
self.auth_handler.set_session_data(
|
||||||
|
session_id, "registered_user_id", user_id
|
||||||
|
)
|
||||||
|
|
||||||
if result and LoginType.EMAIL_IDENTITY in result:
|
if result and LoginType.EMAIL_IDENTITY in result:
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
|
|
||||||
|
@ -22,9 +22,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
side_effect=lambda x: defer.succeed(self.appservice))
|
side_effect=lambda x: defer.succeed(self.appservice))
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth_result = (False, None, None)
|
self.auth_result = (False, None, None, None)
|
||||||
self.auth_handler = Mock(
|
self.auth_handler = Mock(
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result)
|
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||||
|
get_session_data=Mock(return_value=None)
|
||||||
)
|
)
|
||||||
self.registration_handler = Mock()
|
self.registration_handler = Mock()
|
||||||
self.identity_handler = Mock()
|
self.identity_handler = Mock()
|
||||||
@ -112,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.auth_result = (True, None, {
|
self.auth_result = (True, None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
})
|
}, None)
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, token))
|
self.registration_handler.register = Mock(return_value=(user_id, token))
|
||||||
|
|
||||||
(code, result) = yield self.servlet.on_POST(self.request)
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.auth_result = (True, None, {
|
self.auth_result = (True, None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
})
|
}, None)
|
||||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||||
d = self.servlet.on_POST(self.request)
|
d = self.servlet.on_POST(self.request)
|
||||||
return self.assertFailure(d, SynapseError)
|
return self.assertFailure(d, SynapseError)
|
||||||
|
Loading…
Reference in New Issue
Block a user