Port the Password Auth Providers module interface to the new generic interface (#10548)

Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
Azrenbeth 2021-10-13 12:21:52 +01:00 committed by GitHub
parent 732bbf6737
commit cdd308845b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 789 additions and 224 deletions

View file

@ -200,46 +200,13 @@ class AuthHandler:
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
# we can't use hs.get_module_api() here, because to do so will create an
# import loop.
#
# TODO: refactor this class to separate the lower-level stuff that
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.authproviders.password_providers
]
logger.info("Extra password_providers: %s", self.password_providers)
self.password_auth_provider = hs.get_password_auth_provider()
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = set()
if self._password_localdb_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
login_types.discard(LoginType.PASSWORD)
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
self._supported_login_types = []
if LoginType.PASSWORD in login_types:
self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -427,11 +394,10 @@ class AuthHandler:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
for t in self.password_auth_provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
@ -1038,7 +1004,25 @@ class AuthHandler:
Returns:
login types
"""
return self._supported_login_types
# Load any login types registered by modules
# This is stored in the password_auth_provider so this doesn't trigger
# any callbacks
types = list(self.password_auth_provider.get_supported_login_types().keys())
# This list should include PASSWORD if (either _password_localdb_enabled is
# true or if one of the modules registered it) AND _password_enabled is true
# Also:
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
if self._password_enabled:
types.insert(0, LoginType.PASSWORD)
elif self._password_localdb_enabled and self._password_enabled:
types.insert(0, LoginType.PASSWORD)
return types
async def validate_login(
self,
@ -1217,15 +1201,20 @@ class AuthHandler:
known_login_type = False
for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
continue
# Check if login_type matches a type registered by one of the modules
# We don't need to remove LoginType.PASSWORD from the list if password login is
# disabled, since if that were the case then by this point we know that the
# login_type is not LoginType.PASSWORD
supported_login_types = self.password_auth_provider.get_supported_login_types()
# check if the login type being used is supported by a module
if login_type in supported_login_types:
# Make a note that this login type is supported by the server
known_login_type = True
# Get all the fields expected for this login types
login_fields = supported_login_types[login_type]
# go through the login submission and keep track of which required fields are
# provided/not provided
missing_fields = []
login_dict = {}
for f in login_fields:
@ -1233,6 +1222,7 @@ class AuthHandler:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
# raise an error if any of the expected fields for that login type weren't provided
if missing_fields:
raise SynapseError(
400,
@ -1240,10 +1230,15 @@ class AuthHandler:
% (login_type, missing_fields),
)
result = await provider.check_auth(username, login_type, login_dict)
# call all of the check_auth hooks for that login_type
# it will return a result once the first success is found (or None otherwise)
result = await self.password_auth_provider.check_auth(
username, login_type, login_dict
)
if result:
return result
# if no module managed to authenticate the user, then fallback to built in password based auth
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
@ -1282,11 +1277,16 @@ class AuthHandler:
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
result = await provider.check_3pid_auth(medium, address, password)
if result:
return result
# call all of the check_3pid_auth callbacks
# Result will be from the first callback that returns something other than None
# If all the callbacks return None, then result is also set to None
result = await self.password_auth_provider.check_3pid_auth(
medium, address, password
)
if result:
return result
# if result is None then return (None, None)
return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@ -1365,13 +1365,12 @@ class AuthHandler:
user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
# see if any modules want to know about this
await self.password_auth_provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
if user_info.token_id is not None:
@ -1398,12 +1397,11 @@ class AuthHandler:
user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
for token, _, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# see if any modules want to know about this
for token, _, device_id in tokens_and_devices:
await self.password_auth_provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@ -1811,40 +1809,228 @@ class MacaroonGenerator:
return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
module_api = hs.get_module_api()
for module, config in hs.config.authproviders.password_providers:
load_single_legacy_password_auth_provider(
module=module, config=config, api=module_api
)
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
def load_single_legacy_password_auth_provider(
module: Type, config: JsonDict, api: ModuleApi
) -> None:
try:
provider = module(config=config, account_handler=api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
# The known hooks. If a module implements a method who's name appears in this set
# we'll want to register it
password_auth_provider_methods = {
"check_3pid_auth",
"on_logged_out",
}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We need to wrap check_password because its old form would return a boolean
# but we now want it to behave just like check_auth() and return the matrix id of
# the user if authentication succeeded or None otherwise
if f.__name__ == "check_password":
async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
matrix_user_id = api.get_qualified_user_id(username)
password = login_dict["password"]
is_valid = await f(matrix_user_id, password)
if is_valid:
return matrix_user_id, None
return None
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(username, login_type, login_dict)
if isinstance(result, str):
return result, None
return result
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(medium, address, password)
if isinstance(result, str):
return result, None
return result
return wrapped_check_3pid_auth
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# populate hooks with the implemented methods, wrapped with async_wrapper
hooks = {
hook: async_wrapper(getattr(provider, hook, None))
for hook in password_auth_provider_methods
}
supported_login_types = {}
# call get_supported_login_types and add that to the dict
g = getattr(provider, "get_supported_login_types", None)
if g is not None:
# Note the old module style also called get_supported_login_types at loading time
# and it is synchronous
supported_login_types.update(g())
auth_checkers = {}
# Legacy modules have a check_auth method which expects to be called with one of
# the keys returned by get_supported_login_types. New style modules register a
# dictionary of login_type->check_auth_method mappings
check_auth = async_wrapper(getattr(provider, "check_auth", None))
if check_auth is not None:
for login_type, fields in supported_login_types.items():
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
auth_checkers[(login_type, tuple(fields))] = check_auth
# if it has a "check_password" method then it should handle all auth checks
# with login type of LoginType.PASSWORD
check_password = async_wrapper(getattr(provider, "check_password", None))
if check_password is not None:
# need to use a tuple here for ("password",) not a list since lists aren't hashable
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
class PasswordAuthProvider:
"""
A class that the AuthHandler calls when authenticating users
It allows modules to provide alternative methods for authentication
"""
@classmethod
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
self._supported_login_types = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
def __str__(self) -> str:
return str(self._pp)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a modules auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self._supported_login_types.get(login_type, []):
self._supported_login_types[login_type] = fields
else:
fields_currently_supported = self._supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@ -1852,20 +2038,15 @@ class PasswordProvider:
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
@ -1879,63 +2060,130 @@ class PasswordProvider:
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await check_auth(username, login_type, login_dict)
# Go through all callbacks for the login type until one returns with a value
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
try:
result = await callback(username, login_type, login_dict)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
return result
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
for callback in self.check_3pid_auth_callbacks:
try:
result = await callback(medium, address, password)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
return result
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
await maybe_awaitable(
g(
user_id=user_id,
device_id=device_id,
access_token=access_token,
)
)
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue