Allow modules to set a display name on registration (#12009)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Brendan Abolivier 2022-02-17 17:54:16 +01:00 committed by GitHub
parent da0e9f8efd
commit 707049c6ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 195 additions and 34 deletions

View File

@ -0,0 +1 @@
Enable modules to set a custom display name when registering a user.

View File

@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
If multiple modules implement this callback, they will be considered in order. If a If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback return `None`, any of the subsequent implementations of this callback. If every callback returns `None`,
the authentication is denied. the authentication is denied.
### `on_logged_out` ### `on_logged_out`
@ -162,10 +162,38 @@ return `None`.
If multiple modules implement this callback, they will be considered in order. If a If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback return `None`, any of the subsequent implementations of this callback. If every callback returns `None`,
the username provided by the user is used, if any (otherwise one is automatically the username provided by the user is used, if any (otherwise one is automatically
generated). generated).
### `get_displayname_for_registration`
_First introduced in Synapse v1.54.0_
```python
async def get_displayname_for_registration(
uia_results: Dict[str, Any],
params: Dict[str, Any],
) -> Optional[str]
```
Called when registering a new user. The module can return a display name to set for the
user being registered by returning it as a string, or `None` if it doesn't wish to force a
display name for this user.
This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
has been completed by the user. It is not called when registering a user via SSO. It is
passed two dictionaries, which include the information that the user has provided during
the registration process. These dictionaries are identical to the ones passed to
[`get_username_for_registration`](#get_username_for_registration), so refer to the
documentation of this callback for more information about them.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback returns `None`,
the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
## `is_3pid_allowed` ## `is_3pid_allowed`
_First introduced in Synapse v1.53.0_ _First introduced in Synapse v1.53.0_
@ -196,7 +224,6 @@ The example module below implements authentication checkers for two different lo
- Expects a `password` field to be sent to `/login` - Expects a `password` field to be sent to `/login`
- Is checked by the method: `self.check_pass` - Is checked by the method: `self.check_pass`
```python ```python
from typing import Awaitable, Callable, Optional, Tuple from typing import Awaitable, Callable, Optional, Tuple

View File

@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict], [JsonDict, JsonDict],
Awaitable[Optional[str]], Awaitable[Optional[str]],
] ]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[ self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = [] ] = []
self.get_displayname_for_registration_callbacks: List[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters # Mapping from login type to login parameters
@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
get_username_for_registration: Optional[ get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None, ] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None: ) -> None:
# Register check_3pid_auth callback # Register check_3pid_auth callback
if check_3pid_auth is not None: if check_3pid_auth is not None:
@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
get_username_for_registration, get_username_for_registration,
) )
if get_displayname_for_registration is not None:
self.get_displayname_for_registration_callbacks.append(
get_displayname_for_registration,
)
if is_3pid_allowed is not None: if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed) self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
return None return None
async def get_displayname_for_registration(
self,
uia_results: JsonDict,
params: JsonDict,
) -> Optional[str]:
"""Defines the display name to use when registering the user, using the
credentials and parameters provided during the UIA flow.
Stops at the first callback that returns a tuple containing at least one string.
Args:
uia_results: The credentials provided during the UIA flow.
params: The parameters provided by the registration request.
Returns:
A tuple which first element is the display name, and the second is an MXC URL
to the user's avatar.
"""
for callback in self.get_displayname_for_registration_callbacks:
try:
res = await callback(uia_results, params)
if isinstance(res, str):
return res
elif res is not None:
# mypy complains that this line is unreachable because it assumes the
# data returned by the module fits the expected type. We just want
# to make sure this is the case.
logger.warning( # type: ignore[unreachable]
"Ignoring non-string value returned by"
" get_displayname_for_registration callback %s: %s",
callback,
res,
)
except Exception as e:
logger.error(
"Module raised an exception in get_displayname_for_registration: %s",
e,
)
raise SynapseError(code=500, msg="Internal Server Error")
return None
async def is_3pid_allowed( async def is_3pid_allowed(
self, self,
medium: str, medium: str,

View File

@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import ( from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK, CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK, CHECK_AUTH_CALLBACK,
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK, IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK, ON_LOGGED_OUT_CALLBACK,
@ -317,6 +318,9 @@ class ModuleApi:
get_username_for_registration: Optional[ get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None, ] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None: ) -> None:
"""Registers callbacks for password auth provider capabilities. """Registers callbacks for password auth provider capabilities.
@ -328,6 +332,7 @@ class ModuleApi:
is_3pid_allowed=is_3pid_allowed, is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers, auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration, get_username_for_registration=get_username_for_registration,
get_displayname_for_registration=get_displayname_for_registration,
) )
def register_background_update_controller_callbacks( def register_background_update_controller_callbacks(

View File

@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
session_id session_id
) )
display_name = await (
self.password_auth_provider.get_displayname_for_registration(
auth_result, params
)
)
registered_user_id = await self.registration_handler.register_user( registered_user_id = await self.registration_handler.register_user(
localpart=desired_username, localpart=desired_username,
password_hash=password_hash, password_hash=password_hash,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
threepid=threepid, threepid=threepid,
default_display_name=display_name,
address=client_addr, address=client_addr,
user_agent_ips=entries, user_agent_ips=entries,
) )

View File

@ -84,7 +84,7 @@ class CustomAuthProvider:
def __init__(self, config, api: ModuleApi): def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks( api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
) )
def check_auth(self, *args): def check_auth(self, *args):
@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
auth_checkers={ auth_checkers={
("test.login_type", ("test_field",)): self.check_auth, ("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth,
}, }
) )
pass pass
@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
account.register_servlets, account.register_servlets,
] ]
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
def setUp(self): def setUp(self):
# we use a global mock device, so make sure we are starting with a clean slate # we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback can define the username """Tests that the get_username_for_registration callback can define the username
of a user when registering. of a user when registering.
""" """
self._setup_get_username_for_registration() self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)
username = "rin" username = "rin"
channel = self.make_request( channel = self.make_request(
@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback is only called at the """Tests that the get_username_for_registration callback is only called at the
end of the UIA flow. end of the UIA flow.
""" """
m = self._setup_get_username_for_registration() m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)
# Initiate the UIA flow.
username = "rin" username = "rin"
channel = self.make_request( res = self._do_uia_assert_mock_not_called(username, m)
"POST",
"register",
{"username": username, "type": "m.login.password", "password": "bar"},
)
self.assertEqual(channel.code, 401)
self.assertIn("session", channel.json_body)
# Check that the callback hasn't been called yet. mxid = res["user_id"]
m.assert_not_called()
# Finish the UIA flow.
session = channel.json_body["session"]
channel = self.make_request(
"POST",
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
# Check that the callback has been called. # Check that the callback has been called.
@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False) self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True) self._test_3pid_allowed("kitay", True)
def test_displayname(self):
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)
username = "rin"
channel = self.make_request(
"POST",
"/register",
{
"username": username,
"password": "bar",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(channel.code, 200)
# Our callback takes the username and appends "-foo" to it, check that's what we
# have.
user_id = UserID.from_string(channel.json_body["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)
self.assertEqual(display_name, username + "-foo")
def test_displayname_uia(self):
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)
username = "rin"
res = self._do_uia_assert_mock_not_called(username, m)
user_id = UserID.from_string(res["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)
self.assertEqual(display_name, username + "-foo")
# Check that the callback has been called.
m.assert_called_once()
def _test_3pid_allowed(self, username: str, registration: bool): def _test_3pid_allowed(self, username: str, registration: bool):
"""Tests that the "is_3pid_allowed" module callback is called correctly, using """Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments. either /register or /account URLs depending on the arguments.
@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "bar@test.com", registration) m.assert_called_once_with("email", "bar@test.com", registration)
def _setup_get_username_for_registration(self) -> Mock: def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the """Registers either a get_username_for_registration callback or a
username the client is trying to register. get_displayname_for_registration callback that appends "-foo" to the username the
client is trying to register.
""" """
async def get_username_for_registration(uia_results, params): async def callback(uia_results, params):
self.assertIn(LoginType.DUMMY, uia_results) self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"] username = params["username"]
return username + "-foo" return username + "-foo"
m = Mock(side_effect=get_username_for_registration) m = Mock(side_effect=callback)
password_auth_provider = self.hs.get_password_auth_provider() password_auth_provider = self.hs.get_password_auth_provider()
password_auth_provider.get_username_for_registration_callbacks.append(m) getattr(password_auth_provider, callback_name + "_callbacks").append(m)
return m return m
def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
# Initiate the UIA flow.
channel = self.make_request(
"POST",
"register",
{"username": username, "type": "m.login.password", "password": "bar"},
)
self.assertEqual(channel.code, 401)
self.assertIn("session", channel.json_body)
# Check that the callback hasn't been called yet.
m.assert_not_called()
# Finish the UIA flow.
session = channel.json_body["session"]
channel = self.make_request(
"POST",
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict: def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)