mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Create a PasswordProvider
wrapper object (#8849)
The idea here is to abstract out all the conditional code which tests which methods a given password provider has, to provide a consistent interface.
This commit is contained in:
parent
edb3d3f827
commit
d3ed93504b
1
changelog.d/8849.misc
Normal file
1
changelog.d/8849.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor `password_auth_provider` support code.
|
@ -1,6 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -25,6 +26,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
|
|||||||
# better way to break the loop
|
# better way to break the loop
|
||||||
account_handler = ModuleApi(hs, self)
|
account_handler = ModuleApi(hs, self)
|
||||||
|
|
||||||
self.password_providers = []
|
self.password_providers = [
|
||||||
for module, config in hs.config.password_providers:
|
PasswordProvider.load(module, config, account_handler)
|
||||||
try:
|
for module, config in hs.config.password_providers
|
||||||
self.password_providers.append(
|
]
|
||||||
module(config=config, account_handler=account_handler)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error while initializing %r: %s", module, e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info("Extra password_providers: %r", self.password_providers)
|
logger.info("Extra password_providers: %s", self.password_providers)
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
|
|||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
login_type = login_submission.get("type")
|
login_type = login_submission.get("type")
|
||||||
|
if not isinstance(login_type, str):
|
||||||
|
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
||||||
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
||||||
@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
|
|||||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
login_type = login_submission.get("type")
|
login_type = login_submission.get("type")
|
||||||
|
# we already checked that we have a valid login type
|
||||||
|
assert isinstance(login_type, str)
|
||||||
|
|
||||||
known_login_type = False
|
known_login_type = False
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
|
|
||||||
known_login_type = True
|
|
||||||
# we've already checked that there is a (valid) password field
|
|
||||||
is_valid = await provider.check_password(
|
|
||||||
qualified_user_id, login_submission["password"]
|
|
||||||
)
|
|
||||||
if is_valid:
|
|
||||||
return qualified_user_id, None
|
|
||||||
|
|
||||||
if not hasattr(provider, "get_supported_login_types") or not hasattr(
|
|
||||||
provider, "check_auth"
|
|
||||||
):
|
|
||||||
# this password provider doesn't understand custom login types
|
|
||||||
continue
|
|
||||||
|
|
||||||
supported_login_types = provider.get_supported_login_types()
|
supported_login_types = provider.get_supported_login_types()
|
||||||
if login_type not in supported_login_types:
|
if login_type not in supported_login_types:
|
||||||
# this password provider doesn't understand this login type
|
# this password provider doesn't understand this login type
|
||||||
@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
result = await provider.check_auth(username, login_type, login_dict)
|
result = await provider.check_auth(username, login_type, login_dict)
|
||||||
if result:
|
if result:
|
||||||
if isinstance(result, str):
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||||
@ -1083,18 +1068,8 @@ class AuthHandler(BaseHandler):
|
|||||||
unsuccessful, `user_id` and `callback` are both `None`.
|
unsuccessful, `user_id` and `callback` are both `None`.
|
||||||
"""
|
"""
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_3pid_auth"):
|
|
||||||
# This function is able to return a deferred that either
|
|
||||||
# resolves None, meaning authentication failure, or upon
|
|
||||||
# success, to a str (which is the user_id) or a tuple of
|
|
||||||
# (user_id, callback_func), where callback_func should be run
|
|
||||||
# after we've finished everything else
|
|
||||||
result = await provider.check_3pid_auth(medium, address, password)
|
result = await provider.check_3pid_auth(medium, address, password)
|
||||||
if result:
|
if result:
|
||||||
# Check if the return value is a str or a tuple
|
|
||||||
if isinstance(result, str):
|
|
||||||
# If it's a str, set callback function to None
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
await provider.on_logged_out(
|
||||||
# This might return an awaitable, if it does block the log out
|
|
||||||
# until it completes.
|
|
||||||
result = provider.on_logged_out(
|
|
||||||
user_id=user_info.user_id,
|
user_id=user_info.user_id,
|
||||||
device_id=user_info.device_id,
|
device_id=user_info.device_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(result):
|
|
||||||
await result
|
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info.token_id is not None:
|
if user_info.token_id is not None:
|
||||||
@ -1191,7 +1161,6 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
|
||||||
for token, token_id, device_id in tokens_and_devices:
|
for token, token_id, device_id in tokens_and_devices:
|
||||||
await provider.on_logged_out(
|
await provider.on_logged_out(
|
||||||
user_id=user_id, device_id=device_id, access_token=token
|
user_id=user_id, device_id=device_id, access_token=token
|
||||||
@ -1519,3 +1488,127 @@ class MacaroonGenerator:
|
|||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordProvider:
|
||||||
|
"""Wrapper for a password auth provider module
|
||||||
|
|
||||||
|
This class abstracts out all of the backwards-compatibility hacks for
|
||||||
|
password providers, to provide a consistent interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||||
|
try:
|
||||||
|
pp = module(config=config, account_handler=module_api)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error while initializing %r: %s", module, e)
|
||||||
|
raise
|
||||||
|
return cls(pp, module_api)
|
||||||
|
|
||||||
|
def __init__(self, pp, module_api: ModuleApi):
|
||||||
|
self._pp = pp
|
||||||
|
self._module_api = module_api
|
||||||
|
|
||||||
|
self._supported_login_types = {}
|
||||||
|
|
||||||
|
# grandfather in check_password support
|
||||||
|
if hasattr(self._pp, "check_password"):
|
||||||
|
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||||
|
|
||||||
|
g = getattr(self._pp, "get_supported_login_types", None)
|
||||||
|
if g:
|
||||||
|
self._supported_login_types.update(g())
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self._pp)
|
||||||
|
|
||||||
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
|
"""Get the login types supported by this password provider
|
||||||
|
|
||||||
|
Returns a map from a login type identifier (such as m.login.password) to an
|
||||||
|
iterable giving the fields which must be provided by the user in the submission
|
||||||
|
to the /login API.
|
||||||
|
|
||||||
|
This wrapper adds m.login.password to the list if the underlying password
|
||||||
|
provider supports the check_password() api.
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
async def check_auth(
|
||||||
|
self, username: str, login_type: str, login_dict: JsonDict
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
"""Check if the user has presented valid login credentials
|
||||||
|
|
||||||
|
This wrapper also calls check_password() if the underlying password provider
|
||||||
|
supports the check_password() api and the login type is m.login.password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: user id presented by the client. Either an MXID or an unqualified
|
||||||
|
username.
|
||||||
|
|
||||||
|
login_type: the login type being attempted - one of the types returned by
|
||||||
|
get_supported_login_types()
|
||||||
|
|
||||||
|
login_dict: the dictionary of login secrets passed by the client.
|
||||||
|
|
||||||
|
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
|
||||||
|
user, and `callback` is an optional callback which will be called with the
|
||||||
|
result from the /login call (including access_token, device_id, etc.)
|
||||||
|
"""
|
||||||
|
# first grandfather in a call to check_password
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
g = getattr(self._pp, "check_password", None)
|
||||||
|
if g:
|
||||||
|
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||||
|
is_valid = await self._pp.check_password(
|
||||||
|
qualified_user_id, login_dict["password"]
|
||||||
|
)
|
||||||
|
if is_valid:
|
||||||
|
return qualified_user_id, None
|
||||||
|
|
||||||
|
g = getattr(self._pp, "check_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
result = await g(username, login_type, login_dict)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def check_3pid_auth(
|
||||||
|
self, medium: str, address: str, password: str
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
g = getattr(self._pp, "check_3pid_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# This function is able to return a deferred that either
|
||||||
|
# resolves None, meaning authentication failure, or upon
|
||||||
|
# success, to a str (which is the user_id) or a tuple of
|
||||||
|
# (user_id, callback_func), where callback_func should be run
|
||||||
|
# after we've finished everything else
|
||||||
|
result = await g(medium, address, password)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def on_logged_out(
|
||||||
|
self, user_id: str, device_id: Optional[str], access_token: str
|
||||||
|
) -> None:
|
||||||
|
g = getattr(self._pp, "on_logged_out", None)
|
||||||
|
if not g:
|
||||||
|
return
|
||||||
|
|
||||||
|
# This might return an awaitable, if it does block the log out
|
||||||
|
# until it completes.
|
||||||
|
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||||||
# first delete should give a 401
|
# first delete should give a 401
|
||||||
channel = self._delete_device(tok1, "dev2")
|
channel = self._delete_device(tok1, "dev2")
|
||||||
self.assertEqual(channel.code, 401)
|
self.assertEqual(channel.code, 401)
|
||||||
# there are no valid flows here!
|
# m.login.password UIA is permitted because the auth provider allows it,
|
||||||
self.assertEqual(channel.json_body["flows"], [])
|
# even though the localdb does not.
|
||||||
|
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user