Let password auth providers handle arbitrary login types

Provide a hook where password auth providers can say they know about other
login types, and get passed the relevant parameters
This commit is contained in:
Richard van der Hoff 2017-10-31 10:43:57 +00:00
parent a72e4e3e28
commit 3cd6b22c7b
2 changed files with 140 additions and 33 deletions

View file

@ -82,6 +82,11 @@ class AuthHandler(BaseHandler):
login_types = set()
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
login_types.update(
provider.get_supported_login_types().keys()
)
self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks
@ -504,14 +509,14 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
@defer.inlineCallbacks
def validate_login(self, user_id, login_submission):
def validate_login(self, username, login_submission):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
user_id (str): user_id supplied by the user
username (str): username supplied by the user
login_submission (dict): the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@ -522,32 +527,81 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
if username.startswith('@'):
qualified_user_id = username
else:
qualified_user_id = UserID(
username, self.hs.hostname
).to_string()
login_type = login_submission.get("type")
known_login_type = False
if login_type != LoginType.PASSWORD:
raise SynapseError(400, "Bad login type.")
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if "password" not in login_submission:
raise SynapseError(400, "Missing parameter: password")
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not password:
raise SynapseError(400, "Missing parameter: password")
password = login_submission["password"]
for provider in self.password_providers:
is_valid = yield provider.check_password(user_id, password)
if is_valid:
defer.returnValue(user_id)
if (hasattr(provider, "check_password")
and login_type == LoginType.PASSWORD):
known_login_type = True
is_valid = yield provider.check_password(
qualified_user_id, password,
)
if is_valid:
defer.returnValue(qualified_user_id)
canonical_user_id = yield self._check_local_password(
user_id, password,
)
if (not hasattr(provider, "get_supported_login_types")
or not hasattr(provider, "check_auth")):
# this password provider doesn't understand custom login types
continue
if canonical_user_id:
defer.returnValue(canonical_user_id)
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
continue
known_login_type = True
login_fields = supported_login_types[login_type]
missing_fields = []
login_dict = {}
for f in login_fields:
if f not in login_submission:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
400, "Missing parameters for login type %s: %s" % (
login_type,
missing_fields,
),
)
returned_user_id = yield provider.check_auth(
username, login_type, login_dict,
)
if returned_user_id:
defer.returnValue(returned_user_id)
if login_type == LoginType.PASSWORD:
known_login_type = True
canonical_user_id = yield self._check_local_password(
qualified_user_id, password,
)
if canonical_user_id:
defer.returnValue(canonical_user_id)
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@ -731,11 +785,31 @@ class _AccountHandler(object):
self._check_user_exists = check_user_exists
def check_user_exists(self, user_id):
"""Check if user exissts.
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns:
Deferred(bool)
str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._check_user_exists(user_id)