Merge pull request #1160 from matrix-org/rav/401_on_password_fail

Interactive Auth: Return 401 from for incorrect password
This commit is contained in:
Richard van der Hoff 2016-10-07 10:57:43 +01:00 committed by GitHub
commit 8681aff4f1

View File

@ -59,7 +59,6 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled self.ldap_enabled = hs.config.ldap_enabled
if self.ldap_enabled: if self.ldap_enabled:
@ -149,13 +148,19 @@ class AuthHandler(BaseHandler):
creds = session['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
errordict = {}
if 'type' in authdict: if 'type' in authdict:
if authdict['type'] not in self.checkers: if authdict['type'] not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED) raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip) try:
if result: result = yield self.checkers[authdict['type']](authdict, clientip)
creds[authdict['type']] = result if result:
self._save_session(session) creds[authdict['type']] = result
self._save_session(session)
except LoginError, e:
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
@ -164,6 +169,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -431,37 +437,40 @@ class AuthHandler(BaseHandler):
defer.Deferred: (str) canonical_user_id, or None if zero or defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches multiple matches
""" """
try: res = yield self._find_user_id_and_pwd_hash(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None:
defer.returnValue(res[0]) defer.returnValue(res[0])
except LoginError: defer.returnValue(None)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches. insensitively, but will return None if there are multiple inexact
matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) elif len(user_infos) == 1:
# a single match (possibly not exact)
if len(user_infos) > 1: result = user_infos.popitem()
if user_id not in user_infos: elif user_id in user_infos:
logger.warn( # multiple matches, but one is exact
"Attempted to login as %s but it matches more than one user " result = (user_id, user_infos[user_id])
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else: else:
defer.returnValue(user_infos.popitem()) # multiple matches, none of them exact
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -475,34 +484,45 @@ class AuthHandler(BaseHandler):
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id
Raises: Raises:
LoginError if the password was incorrect LoginError if login fails
""" """
valid_ldap = yield self._check_ldap_password(user_id, password) valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap: if valid_ldap:
defer.returnValue(user_id) defer.returnValue(user_id)
result = yield self._check_local_password(user_id, password) canonical_user_id = yield self._check_local_password(user_id, password)
defer.returnValue(result)
if canonical_user_id:
defer.returnValue(canonical_user_id)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are user_id is checked case insensitively, but will return None if there are
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (str): complete @user:id
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id, or None if unknown user / bad password
Raises:
LoginError if the password was incorrect
""" """
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(None)
defer.returnValue(user_id) defer.returnValue(user_id)
def _ldap_simple_bind(self, server, localpart, password): def _ldap_simple_bind(self, server, localpart, password):