diff --git a/changelog.d/11790.feature b/changelog.d/11790.feature new file mode 100644 index 000000000..4a5cc8ec3 --- /dev/null +++ b/changelog.d/11790.feature @@ -0,0 +1 @@ +Add a module callback to set username at registration. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index e53abf640..ec8324d29 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -105,6 +105,68 @@ device ID), and the (now deactivated) access token. If multiple modules implement this callback, Synapse runs them all in order. +### `get_username_for_registration` + +_First introduced in Synapse v1.52.0_ + +```python +async def get_username_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a username to set for the user +being registered by returning it as a string, or `None` if it doesn't wish to force a +username for this user. If a username is returned, it will be used as the local part of a +user's full Matrix ID (e.g. it's `alice` in `@alice:example.com`). + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. + +The first dictionary contains the results of the [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +flow followed by the user. Its keys are the identifiers of every step involved in the flow, +associated with either a boolean value indicating whether the step was correctly completed, +or additional information (e.g. email address, phone number...). A list of most existing +identifiers can be found in the [Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#authentication-types). +Here's an example featuring all currently supported keys: + +```python +{ + "m.login.dummy": True, # Dummy authentication + "m.login.terms": True, # User has accepted the terms of service for the homeserver + "m.login.recaptcha": True, # User has completed the recaptcha challenge + "m.login.email.identity": { # User has provided and verified an email address + "medium": "email", + "address": "alice@example.com", + "validated_at": 1642701357084, + }, + "m.login.msisdn": { # User has provided and verified a phone number + "medium": "msisdn", + "address": "33123456789", + "validated_at": 1642701357084, + }, + "org.matrix.msc3231.login.registration_token": "sometoken", # User has registered through the flow described in MSC3231 +} +``` + +The second dictionary contains the parameters provided by the user's client in the request +to `/_matrix/client/v3/register`. See the [Matrix specification](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3register) +for a complete list of these parameters. + +If the module cannot, or does not wish to, generate a username for this user, it must +return `None`. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback return `None`, +the username provided by the user is used, if any (otherwise one is automatically +generated). + + ## Example The example module below implements authentication checkers for two different login types: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bd1a32256..e32c93e23 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] ], ] +GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] class PasswordAuthProvider: @@ -2072,6 +2076,9 @@ class PasswordAuthProvider: # lists of callbacks self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + self.get_username_for_registration_callbacks: List[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = [] # Mapping from login type to login parameters self._supported_login_types: Dict[str, Iterable[str]] = {} @@ -2086,6 +2093,9 @@ class PasswordAuthProvider: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2130,6 +2140,11 @@ class PasswordAuthProvider: # Add the new method to the list of auth_checker_callbacks for this login type self.auth_checker_callbacks.setdefault(login_type, []).append(callback) + if get_username_for_registration is not None: + self.get_username_for_registration_callbacks.append( + get_username_for_registration, + ) + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -2285,3 +2300,46 @@ class PasswordAuthProvider: except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue + + async def get_username_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the username to use when registering the user, using the credentials + and parameters provided during the UIA flow. + + Stops at the first callback that returns a string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + The localpart to use when registering this user, or None if no module + returned a localpart. + """ + for callback in self.get_username_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_username_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_username_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 662e60bc3..788b2e47d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -71,6 +71,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_USERNAME_FOR_REGISTRATION_CALLBACK, ON_LOGGED_OUT_CALLBACK, AuthHandler, ) @@ -177,6 +178,7 @@ class ModuleApi: self._presence_stream = hs.get_event_sources().sources.presence self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() + self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self.custom_template_dir = hs.config.server.custom_template_directory @@ -310,6 +312,9 @@ class ModuleApi: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -319,6 +324,7 @@ class ModuleApi: check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, auth_checkers=auth_checkers, + get_username_for_registration=get_username_for_registration, ) def register_background_update_controller_callbacks( @@ -1202,6 +1208,22 @@ class ModuleApi: """ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) + async def check_username(self, username: str) -> None: + """Checks if the provided username uses the grammar defined in the Matrix + specification, and is already being used by an existing user. + + Added in Synapse v1.52.0. + + Args: + username: The username to check. This is the local part of the user's full + Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com". + + Raises: + SynapseError with the errcode "M_USER_IN_USE" if the username is already in + use. + """ + await self._registration_handler.check_username(username) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c59dae7c0..e3492f9f9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self.password_auth_provider = hs.get_password_auth_provider() self._registration_enabled = self.hs.config.registration.enable_registration self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None @@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = params.get("username", None) + desired_username = await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) + ) + + if desired_username is None: + desired_username = params.get("username", None) + guest_access_token = params.get("guest_access_token", None) if desired_username is not None: diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 2add72b28..94809cb8b 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,10 +20,11 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.api.constants import LoginType from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout -from synapse.types import JsonDict +from synapse.rest.client import devices, login, logout, register +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeChannel @@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login.register_servlets, devices.register_servlets, logout.register_servlets, + register.register_servlets, ] def setUp(self): @@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) + def test_username(self): + """Tests that the get_username_for_registration callback can define the username + of a user when registering. + """ + self._setup_get_username_for_registration() + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + def test_username_uia(self): + """Tests that the get_username_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_username_for_registration() + + # Initiate the UIA flow. + username = "rin" + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + + def _setup_get_username_for_registration(self) -> Mock: + """Registers a get_username_for_registration callback that appends "-foo" to the + username the client is trying to register. + """ + + async def get_username_for_registration(uia_results, params): + self.assertIn(LoginType.DUMMY, uia_results) + username = params["username"] + return username + "-foo" + + m = Mock(side_effect=get_username_for_registration) + + password_auth_provider = self.hs.get_password_auth_provider() + password_auth_provider.get_username_for_registration_callbacks.append(m) + + return m + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result)