mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-18 12:44:20 -05:00
Make registration idempotent: if you specify the same session, make it give you an access token for the user that was registered on previous uses of that session. Tweak the UI auth layer to not delete sessions when their auth has completed and hence expire themn so they don't hang around until server restart. Allow server-side data to be associated with UI auth sessions.
This commit is contained in:
parent
add89a03a6
commit
c12b9d719a
@ -27,6 +27,7 @@ import logging
|
|||||||
import bcrypt
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
import simplejson
|
import simplejson
|
||||||
|
import time
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
|
SESSION_EXPIRE_SECS = 48 * 60 * 60
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
@ -66,15 +68,18 @@ class AuthHandler(BaseHandler):
|
|||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
clientip (str): The IP address of the client.
|
clientip (str): The IP address of the client.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (authed, dict, dict) where authed is true if the client
|
A tuple of (authed, dict, dict, session_id) where authed is true if
|
||||||
has successfully completed an auth flow. If it is true, the first
|
the client has successfully completed an auth flow. If it is true
|
||||||
dict contains the authenticated credentials of each stage.
|
the first dict contains the authenticated credentials of each stage.
|
||||||
|
|
||||||
If authed is false, the first dictionary is the server response to
|
If authed is false, the first dictionary is the server response to
|
||||||
the login request and should be passed back to the client.
|
the login request and should be passed back to the client.
|
||||||
|
|
||||||
In either case, the second dict contains the parameters for this
|
In either case, the second dict contains the parameters for this
|
||||||
request (which may have been given only in a previous call).
|
request (which may have been given only in a previous call).
|
||||||
|
|
||||||
|
session_id is the ID of this session, either passed in by the client
|
||||||
|
or assigned by the call to check_auth
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authdict = None
|
authdict = None
|
||||||
@ -103,7 +108,10 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(False, self._auth_dict_for_flows(flows, session), clientdict)
|
(
|
||||||
|
False, self._auth_dict_for_flows(flows, session),
|
||||||
|
clientdict, session['id']
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'creds' not in session:
|
if 'creds' not in session:
|
||||||
@ -122,12 +130,11 @@ class AuthHandler(BaseHandler):
|
|||||||
for f in flows:
|
for f in flows:
|
||||||
if len(set(f) - set(creds.keys())) == 0:
|
if len(set(f) - set(creds.keys())) == 0:
|
||||||
logger.info("Auth completed with creds: %r", creds)
|
logger.info("Auth completed with creds: %r", creds)
|
||||||
self._remove_session(session)
|
defer.returnValue((True, creds, clientdict, session['id']))
|
||||||
defer.returnValue((True, creds, clientdict))
|
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
ret['completed'] = creds.keys()
|
ret['completed'] = creds.keys()
|
||||||
defer.returnValue((False, ret, clientdict))
|
defer.returnValue((False, ret, clientdict, session['id']))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||||
@ -154,6 +161,29 @@ class AuthHandler(BaseHandler):
|
|||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
def set_session_data(self, session_id, key, value):
|
||||||
|
"""
|
||||||
|
Store a key-value pair into the sessions data associated with this
|
||||||
|
request. This data is stored server-side and cannot be modified by
|
||||||
|
the client.
|
||||||
|
:param session_id: (string) The ID of this session as returned from check_auth
|
||||||
|
:param key: (string) The key to store the data under
|
||||||
|
:param value: (any) The data to store
|
||||||
|
"""
|
||||||
|
sess = self._get_session_info(session_id)
|
||||||
|
sess.setdefault('serverdict', {})[key] = value
|
||||||
|
self._save_session(sess)
|
||||||
|
|
||||||
|
def get_session_data(self, session_id, key, default=None):
|
||||||
|
"""
|
||||||
|
Retrieve data stored with set_session_data
|
||||||
|
:param session_id: (string) The ID of this session as returned from check_auth
|
||||||
|
:param key: (string) The key to store the data under
|
||||||
|
:param default: (any) Value to return if the key has not been set
|
||||||
|
"""
|
||||||
|
sess = self._get_session_info(session_id)
|
||||||
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_password_auth(self, authdict, _):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
@ -263,7 +293,7 @@ class AuthHandler(BaseHandler):
|
|||||||
if not session_id:
|
if not session_id:
|
||||||
# create a new session
|
# create a new session
|
||||||
while session_id is None or session_id in self.sessions:
|
while session_id is None or session_id in self.sessions:
|
||||||
session_id = stringutils.random_string(24)
|
session_id = stringutils.random_string_with_symbols(24)
|
||||||
self.sessions[session_id] = {
|
self.sessions[session_id] = {
|
||||||
"id": session_id,
|
"id": session_id,
|
||||||
}
|
}
|
||||||
@ -455,11 +485,17 @@ class AuthHandler(BaseHandler):
|
|||||||
def _save_session(self, session):
|
def _save_session(self, session):
|
||||||
# TODO: Persistent storage
|
# TODO: Persistent storage
|
||||||
logger.debug("Saving session %s", session)
|
logger.debug("Saving session %s", session)
|
||||||
|
session["last_used"] = time.time()
|
||||||
self.sessions[session["id"]] = session
|
self.sessions[session["id"]] = session
|
||||||
|
self._prune_sessions()
|
||||||
|
|
||||||
def _remove_session(self, session):
|
def _prune_sessions(self):
|
||||||
logger.debug("Removing session %s", session)
|
for sid,sess in self.sessions.items():
|
||||||
del self.sessions[session["id"]]
|
last_used = 0
|
||||||
|
if 'last_used' in sess:
|
||||||
|
last_used = sess['last_used']
|
||||||
|
if last_used < time.time() - AuthHandler.SESSION_EXPIRE_SECS:
|
||||||
|
del self.sessions[sid]
|
||||||
|
|
||||||
def hash(self, password):
|
def hash(self, password):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
@ -139,7 +139,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
]
|
]
|
||||||
|
|
||||||
authed, result, params = yield self.auth_handler.check_auth(
|
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,6 +147,24 @@ class RegisterRestServlet(RestServlet):
|
|||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# have we already registered a user for this session
|
||||||
|
registered_user_id = self.auth_handler.get_session_data(
|
||||||
|
session_id, "registered_user_id", None
|
||||||
|
)
|
||||||
|
if registered_user_id is not None:
|
||||||
|
logger.info(
|
||||||
|
"Already registered user ID %r for this session",
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
|
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
||||||
|
refresh_token = yield self.auth_handler.issue_refresh_token(registered_user_id)
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"user_id": registered_user_id,
|
||||||
|
"access_token": access_token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}))
|
||||||
|
|
||||||
# NB: This may be from the auth handler and NOT from the POST
|
# NB: This may be from the auth handler and NOT from the POST
|
||||||
if 'password' not in params:
|
if 'password' not in params:
|
||||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||||
@ -161,6 +179,13 @@ class RegisterRestServlet(RestServlet):
|
|||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# remember that we've now registered that user account, and with what
|
||||||
|
# user ID (since the user may not have specified)
|
||||||
|
logger.info("%r", body)
|
||||||
|
self.auth_handler.set_session_data(
|
||||||
|
session_id, "registered_user_id", user_id
|
||||||
|
)
|
||||||
|
|
||||||
if result and LoginType.EMAIL_IDENTITY in result:
|
if result and LoginType.EMAIL_IDENTITY in result:
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user