support custom login types for validating users

Wire the custom login type support from password providers into the UI-auth
user-validation flows.
This commit is contained in:
Richard van der Hoff 2017-12-04 16:49:40 +00:00
parent cc58e177f3
commit da1010c83a

View File

@ -49,7 +49,6 @@ class AuthHandler(BaseHandler):
""" """
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = { self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha, LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn, LoginType.MSISDN: self._check_msisdn,
@ -78,15 +77,20 @@ class AuthHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
login_types = set() # we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
login_types = []
if self._password_enabled: if self._password_enabled:
login_types.add(LoginType.PASSWORD) login_types.append(LoginType.PASSWORD)
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"): if hasattr(provider, "get_supported_login_types"):
login_types.update( for t in provider.get_supported_login_types().keys():
provider.get_supported_login_types().keys() if t not in login_types:
) login_types.append(t)
self._supported_login_types = frozenset(login_types) self._supported_login_types = login_types
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip): def validate_user_via_ui_auth(self, requester, request_body, clientip):
@ -116,14 +120,27 @@ class AuthHandler(BaseHandler):
a different user to `requester` a different user to `requester`
""" """
# we only support password login here # build a list of supported flows
flows = [[LoginType.PASSWORD]] flows = [
[login_type] for login_type in self._supported_login_types
]
result, params, _ = yield self.check_auth( result, params, _ = yield self.check_auth(
flows, request_body, clientip, flows, request_body, clientip,
) )
user_id = result[LoginType.PASSWORD] # find the completed login type
for login_type in self._supported_login_types:
if login_type not in result:
continue
user_id = result[login_type]
break
else:
# this can't happen
raise Exception(
"check_auth returned True but no successful login type",
)
# check that the UI auth matched the access token # check that the UI auth matched the access token
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
@ -210,14 +227,12 @@ class AuthHandler(BaseHandler):
errordict = {} errordict = {}
if 'type' in authdict: if 'type' in authdict:
login_type = authdict['type'] login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
try: try:
result = yield self.checkers[login_type](authdict, clientip) result = yield self._check_auth_dict(authdict, clientip)
if result: if result:
creds[login_type] = result creds[login_type] = result
self._save_session(session) self._save_session(session)
except LoginError, e: except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY: if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new # riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it # validation token (thus sending a new email) each time it
@ -226,7 +241,7 @@ class AuthHandler(BaseHandler):
# #
# Grandfather in the old behaviour for now to avoid # Grandfather in the old behaviour for now to avoid
# breaking old riot deployments. # breaking old riot deployments.
raise e raise
# this step failed. Merge the error dict into the response # this step failed. Merge the error dict into the response
# so that the client can have another go. # so that the client can have another go.
@ -323,17 +338,35 @@ class AuthHandler(BaseHandler):
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_auth_dict(self, authdict, clientip):
if "user" not in authdict or "password" not in authdict: """Attempt to validate the auth dict provided by a client
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"] Args:
password = authdict["password"] authdict (object): auth dict provided by the client
clientip (str): IP address of the client
(canonical_id, callback) = yield self.validate_login(user_id, { Returns:
"type": LoginType.PASSWORD, Deferred: result of the stage verification.
"password": password,
}) Raises:
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = authdict['type']
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker(authdict, clientip)
defer.returnValue(res)
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
defer.returnValue(canonical_id) defer.returnValue(canonical_id)
@defer.inlineCallbacks @defer.inlineCallbacks