Merge pull request #2620 from matrix-org/rav/auth_non_password

Let password auth providers handle arbitrary login types
This commit is contained in:
Richard van der Hoff 2017-11-01 16:45:33 +00:00 committed by GitHub
commit 846a94fbc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 33 deletions

View File

@ -30,22 +30,55 @@ Password auth provider classes must provide the following methods:
and a ``synapse.handlers.auth._AccountHandler`` object which allows the and a ``synapse.handlers.auth._AccountHandler`` object which allows the
password provider to check if accounts exist and/or create new ones. password provider to check if accounts exist and/or create new ones.
``someprovider.check_password``\(*user_id*, *password*)
This is the method that actually does the work. It is passed a qualified
``@localpart:domain`` user id, and the password provided by the user.
The method should return a Twisted ``Deferred`` object, which resolves to
``True`` if authentication is successful, and ``False`` if not.
Optional methods Optional methods
---------------- ----------------
Password provider classes may optionally provide the following methods. Password auth provider classes may optionally provide the following methods.
*class* ``SomeProvider.get_db_schema_files()`` *class* ``SomeProvider.get_db_schema_files``\()
This method, if implemented, should return an Iterable of ``(name, This method, if implemented, should return an Iterable of ``(name,
stream)`` pairs of database schema files. Each file is applied in turn at stream)`` pairs of database schema files. Each file is applied in turn at
initialisation, and a record is then made in the database so that it is initialisation, and a record is then made in the database so that it is
not re-applied on the next start. not re-applied on the next start.
``someprovider.get_supported_login_types``\()
This method, if implemented, should return a ``dict`` mapping from a login
type identifier (such as ``m.login.password``) to an iterable giving the
fields which must be provided by the user in the submission to the
``/login`` api. These fields are passed in the ``login_dict`` dictionary
to ``check_auth``.
For example, if a password auth provider wants to implement a custom login
type of ``com.example.custom_login``, where the client is expected to pass
the fields ``secret1`` and ``secret2``, the provider should implement this
method and return the following dict::
{"com.example.custom_login": ("secret1", "secret2")}
``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
This method is the one that does the real work. If implemented, it will be
called for each login attempt where the login type matches one of the keys
returned by ``get_supported_login_types``.
It is passed the (possibly UNqualified) ``user`` provided by the client,
the login type, and a dictionary of login secrets passed by the client.
The method should return a Twisted ``Deferred`` object, which resolves to
the canonical ``@localpart:domain`` user id if authentication is successful,
and ``None`` if not.
``someprovider.check_password``\(*user_id*, *password*)
This method provides a simpler interface than ``get_supported_login_types``
and ``check_auth`` for password auth providers that just want to provide a
mechanism for validating ``m.login.password`` logins.
Iif implemented, it will be called to check logins with an
``m.login.password`` login type. It is passed a qualified
``@localpart:domain`` user id, and the password provided by the user.
The method should return a Twisted ``Deferred`` object, which resolves to
``True`` if authentication is successful, and ``False`` if not.

View File

@ -81,6 +81,11 @@ class AuthHandler(BaseHandler):
login_types = set() login_types = set()
if self._password_enabled: if self._password_enabled:
login_types.add(LoginType.PASSWORD) login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
login_types.update(
provider.get_supported_login_types().keys()
)
self._supported_login_types = frozenset(login_types) self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -501,14 +506,14 @@ class AuthHandler(BaseHandler):
return self._supported_login_types return self._supported_login_types
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_login(self, user_id, login_submission): def validate_login(self, username, login_submission):
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate Also used by the user-interactive auth flow to validate
m.login.password auth types. m.login.password auth types.
Args: Args:
user_id (str): user_id supplied by the user username (str): username supplied by the user
login_submission (dict): the whole of the login submission login_submission (dict): the whole of the login submission
(including 'type' and other relevant fields) (including 'type' and other relevant fields)
Returns: Returns:
@ -519,33 +524,82 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
if not user_id.startswith('@'): if username.startswith('@'):
user_id = UserID( qualified_user_id = username
user_id, self.hs.hostname else:
qualified_user_id = UserID(
username, self.hs.hostname
).to_string() ).to_string()
login_type = login_submission.get("type") login_type = login_submission.get("type")
known_login_type = False
if login_type != LoginType.PASSWORD: # special case to check for "password" for the check_password interface
raise SynapseError(400, "Bad login type.") # for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled: if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.") raise SynapseError(400, "Password login has been disabled.")
if "password" not in login_submission: if not password:
raise SynapseError(400, "Missing parameter: password") raise SynapseError(400, "Missing parameter: password")
password = login_submission["password"]
for provider in self.password_providers: for provider in self.password_providers:
is_valid = yield provider.check_password(user_id, password) if (hasattr(provider, "check_password")
and login_type == LoginType.PASSWORD):
known_login_type = True
is_valid = yield provider.check_password(
qualified_user_id, password,
)
if is_valid: if is_valid:
defer.returnValue(user_id) defer.returnValue(qualified_user_id)
if (not hasattr(provider, "get_supported_login_types")
or not hasattr(provider, "check_auth")):
# this password provider doesn't understand custom login types
continue
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
continue
known_login_type = True
login_fields = supported_login_types[login_type]
missing_fields = []
login_dict = {}
for f in login_fields:
if f not in login_submission:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
400, "Missing parameters for login type %s: %s" % (
login_type,
missing_fields,
),
)
returned_user_id = yield provider.check_auth(
username, login_type, login_dict,
)
if returned_user_id:
defer.returnValue(returned_user_id)
if login_type == LoginType.PASSWORD:
known_login_type = True
canonical_user_id = yield self._check_local_password( canonical_user_id = yield self._check_local_password(
user_id, password, qualified_user_id, password,
) )
if canonical_user_id: if canonical_user_id:
defer.returnValue(canonical_user_id) defer.returnValue(canonical_user_id)
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note # unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors # that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway. # into a 401 anyway.
@ -773,11 +827,31 @@ class _AccountHandler(object):
self._check_user_exists = check_user_exists self._check_user_exists = check_user_exists
def check_user_exists(self, user_id): def get_qualified_user_id(self, username):
"""Check if user exissts. """Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns: Returns:
Deferred(bool) str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
""" """
return self._check_user_exists(user_id) return self._check_user_exists(user_id)