Merge pull request #928 from matrix-org/rav/refactor_login

Refactor login flow
This commit is contained in:
Richard van der Hoff 2016-07-18 16:12:35 +01:00 committed by GitHub
commit 93efcb8526
2 changed files with 82 additions and 65 deletions

View File

@ -230,7 +230,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
@ -240,11 +239,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
if not (yield self._check_password(user_id, password)): return self._check_password(user_id, password)
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -348,67 +343,66 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks def validate_password_login(self, user_id, password):
def login_with_password(self, user_id, password):
""" """
Authenticates the user with their username and password. Authenticates the user with their username and password.
Used only by the v1 login API. Used only by the v1 login API.
Args: Args:
user_id (str): User ID user_id (str): complete @user:id
password (str): Password password (str): Password
Returns: Returns:
A tuple of: defer.Deferred: (str) canonical user id
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem accessing the database
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
return self._check_password(user_id, password)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id): def get_login_tuple_for_user_id(self, user_id):
""" """
Gets login tuple for the user with the given user ID. Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS) machanism (e.g. CAS), and the user_id converted to the canonical case.
Args: Args:
user_id (str): User ID user_id (str): canonical User ID
Returns: Returns:
A tuple of: A tuple of:
The user's ID.
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session. The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def does_user_exist(self, user_id): def check_user_exists(self, user_id):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
(str) user_id: complete @user:id
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches
"""
try: try:
yield self._find_user_id_and_pwd_hash(user_id) res = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True) defer.returnValue(res[0])
except LoginError: except LoginError:
defer.returnValue(False) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
@ -438,27 +432,45 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
""" """Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns: Returns:
True if the user_id successfully authenticated (str) the canonical_user_id
Raises:
LoginError if the password was incorrect
""" """
valid_ldap = yield self._check_ldap_password(user_id, password) valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap: if valid_ldap:
defer.returnValue(True) defer.returnValue(user_id)
valid_local_password = yield self._check_local_password(user_id, password) result = yield self._check_local_password(user_id, password)
if valid_local_password: defer.returnValue(result)
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
try: """Authenticate a user against the local password database.
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(self.validate_hash(password, password_hash)) user_id is checked case insensitively, but will throw if there are
except LoginError: multiple inexact matches.
defer.returnValue(False)
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_ldap_password(self, user_id, password): def _check_ldap_password(self, user_id, password):
@ -570,7 +582,7 @@ class AuthHandler(BaseHandler):
) )
# check for existing account, if none exists, create one # check for existing account, if none exists, create one
if not (yield self.does_user_exist(user_id)): if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation # query user metadata for account creation
query = "({prop}={value})".format( query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'], prop=self.ldap_attributes['uid'],

View File

@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet):
).to_string() ).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id = yield auth_handler.validate_password_login(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"],
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
user_id, access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(user_id)
) )
result = { result = {
@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if user_exists: if registered_user_id:
user_id, access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": registered_user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if user_exists: if registered_user_id:
user_id, access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": registered_user_id,
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if not user_exists: if not registered_user_id:
user_id, _ = ( registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
login_token = auth_handler.generate_short_term_login_token(user_id) login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token) login_token)
request.redirect(redirect_url) request.redirect(redirect_url)