synapse-product/tests/handlers/test_password_providers.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

877 lines
34 KiB
Python
Raw Normal View History

# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the password_auth_provider interface"""
from typing import Any, List, Type, Union
2021-04-09 13:44:38 -04:00
from unittest.mock import Mock
from twisted.internet import defer
import synapse
from synapse.api.constants import LoginType
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import override_config
# Login flows which the server advertises after the "normal" ones. Some of these
# identifiers are experimental (unstable-prefixed).
ADDITIONAL_LOGIN_FLOWS = [
    {"type": login_type}
    for login_type in (
        "m.login.application_service",
        "uk.half-shot.msc2778.login.application_service",
    )
]

# Shared mock which every dummy auth provider below forwards its calls to, so that
# the tests can configure return values and inspect how the providers were invoked.
mock_password_provider = Mock()
class LegacyPasswordOnlyAuthProvider:
    """Legacy-style password provider offering nothing but `check_password`.

    Every call is forwarded to the module-level `mock_password_provider`.
    """

    @staticmethod
    def parse_config(self):
        # No configuration is used. Because this is a staticmethod, the argument
        # named `self` is actually the provider's config.
        return None

    def __init__(self, config, account_handler):
        """Neither the config nor the account handler is needed by this dummy."""

    def check_password(self, *call_args):
        forward = mock_password_provider.check_password
        return forward(*call_args)
class LegacyCustomAuthProvider:
    """Legacy-style password provider supporting only the custom `test.login_type`.

    That login type requires a single `test_field` parameter. Auth checks are
    forwarded to the module-level `mock_password_provider`.
    """

    @staticmethod
    def parse_config(self):
        # Staticmethod: `self` here is really the config, which is ignored.
        return None

    def __init__(self, config, account_handler):
        """The dummy provider keeps no state."""

    def get_supported_login_types(self):
        supported = {"test.login_type": ["test_field"]}
        return supported

    def check_auth(self, *call_args):
        forward = mock_password_provider.check_auth
        return forward(*call_args)
class CustomAuthProvider:
    """Module-style provider for the custom `test.login_type` login type.

    On construction it registers an auth checker for that type, which requires the
    `test_field` parameter, through the module API. Checks are forwarded to the
    module-level `mock_password_provider`.
    """

    @staticmethod
    def parse_config(self):
        # Staticmethod: `self` here is really the module config, which is ignored.
        return None

    def __init__(self, config, api: ModuleApi):
        checker_key = ("test.login_type", ("test_field",))
        api.register_password_auth_provider_callbacks(
            auth_checkers={checker_key: self.check_auth},
        )

    def check_auth(self, *call_args):
        forward = mock_password_provider.check_auth
        return forward(*call_args)
class LegacyPasswordCustomAuthProvider:
    """Legacy-style provider which handles `m.login.password` through `check_auth`,
    alongside the custom `test.login_type`.

    Auth checks are forwarded to the module-level `mock_password_provider`.
    """

    @staticmethod
    def parse_config(self):
        # Staticmethod: `self` here is really the config, which is ignored.
        return None

    def __init__(self, config, account_handler):
        """The dummy provider keeps no state."""

    def get_supported_login_types(self):
        supported = {
            "m.login.password": ["password"],
            "test.login_type": ["test_field"],
        }
        return supported

    def check_auth(self, *call_args):
        forward = mock_password_provider.check_auth
        return forward(*call_args)
class PasswordCustomAuthProvider:
    """A module which registers password_auth_provider callbacks for a custom login
    type, as well as for password login."""

    @staticmethod
    def parse_config(self):
        # No configuration is used. Because this is a staticmethod, the argument
        # named `self` is actually the module config.
        pass

    def __init__(self, config, api: ModuleApi):
        # Both the custom login type and m.login.password are routed to check_auth.
        api.register_password_auth_provider_callbacks(
            auth_checkers={
                ("test.login_type", ("test_field",)): self.check_auth,
                ("m.login.password", ("password",)): self.check_auth,
            },
        )

    def check_auth(self, *args):
        return mock_password_provider.check_auth(*args)

    def check_pass(self, *args):
        # NOTE: __init__ does not register this method as a callback. Password
        # logins are routed through check_auth instead.
        return mock_password_provider.check_password(*args)
def legacy_providers_config(*providers: Type[Any]) -> dict:
    """Build a config dict which enables each of `providers` as a legacy
    `password_providers` entry.

    Each provider is referenced by its dotted path within this module and is given an
    empty config.
    """
    entries = []
    for provider in providers:
        module_path = f"{__name__}.{provider.__qualname__}"
        entries.append({"module": module_path, "config": {}})
    return {"password_providers": entries}
def providers_config(*providers: Type[Any]) -> dict:
    """Build a config dict which enables each of `providers` as a `modules` entry.

    Each module is referenced by its dotted path within this module and is given an
    empty config.
    """
    entries = []
    for provider in providers:
        module_path = f"{__name__}.{provider.__qualname__}"
        entries.append({"module": module_path, "config": {}})
    return {"modules": entries}
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
    """Tests for login and user-interactive auth (UIA) via password auth providers.

    Most scenarios are written once as a `*_test_body` helper. Each helper is then run
    against a legacy `password_providers` implementation and, where one exists, a
    module which registers its callbacks through the module API.
    """

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
        devices.register_servlets,
        logout.register_servlets,
        register.register_servlets,
    ]

    def setUp(self):
        # we use a global mock password provider, so make sure we are starting with
        # a clean slate
        mock_password_provider.reset_mock()
        super().setUp()

    def make_homeserver(self, reactor, clock):
        hs = self.setup_test_homeserver()
        # Load the modules into the homeserver
        module_api = hs.get_module_api()
        for module, config in hs.config.modules.loaded_modules:
            module(config=config, api=module_api)
        load_legacy_password_auth_providers(hs)

        return hs

    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
    def test_password_only_auth_progiver_login_legacy(self):
        self.password_only_auth_provider_login_test_body()

    def password_only_auth_provider_login_test_body(self):
        # login flows should only have m.login.password
        flows = self._get_login_flows()
        self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)

        # check_password must return an awaitable
        mock_password_provider.check_password.return_value = defer.succeed(True)
        channel = self._send_password_login("u", "p")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@u:test", channel.json_body["user_id"])
        mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
        mock_password_provider.reset_mock()

        # login with mxid should work too
        channel = self._send_password_login("@u:bz", "p")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@u:bz", channel.json_body["user_id"])
        mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
        mock_password_provider.reset_mock()

        # try a weird username / pass. Honestly it's unclear what we *expect* to happen
        # in these cases, but at least we can guard against the API changing
        # unexpectedly
        channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
        mock_password_provider.check_password.assert_called_once_with(
            "@ USER🙂NAME :test", " pASS😢word "
        )

    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
    def test_password_only_auth_provider_ui_auth_legacy(self):
        self.password_only_auth_provider_ui_auth_test_body()

    def password_only_auth_provider_ui_auth_test_body(self):
        """UI Auth should delegate correctly to the password provider"""

        # create the user, otherwise access doesn't work
        module_api = self.hs.get_module_api()
        self.get_success(module_api.register_user("u"))

        # log in twice, to get two devices
        mock_password_provider.check_password.return_value = defer.succeed(True)
        tok1 = self.login("u", "p")
        self.login("u", "p", device_id="dev2")
        mock_password_provider.reset_mock()

        # have the auth provider deny the request to start with
        mock_password_provider.check_password.return_value = defer.succeed(False)

        # make the initial request which returns a 401
        session = self._start_delete_device_session(tok1, "dev2")
        mock_password_provider.check_password.assert_not_called()

        # Make another request providing the UI auth flow.
        channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
        self.assertEqual(channel.code, 401)  # XXX why not a 403?
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
        mock_password_provider.reset_mock()

        # Finally, check the request goes through when we allow it
        mock_password_provider.check_password.return_value = defer.succeed(True)
        channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
        self.assertEqual(channel.code, 200)
        mock_password_provider.check_password.assert_called_once_with("@u:test", "p")

    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
    def test_local_user_fallback_login_legacy(self):
        self.local_user_fallback_login_test_body()

    def local_user_fallback_login_test_body(self):
        """rejected login should fall back to local db"""
        self.register_user("localuser", "localpass")

        # check_password must return an awaitable
        mock_password_provider.check_password.return_value = defer.succeed(False)
        channel = self._send_password_login("u", "p")
        self.assertEqual(channel.code, 403, channel.result)

        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@localuser:test", channel.json_body["user_id"])

    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
    def test_local_user_fallback_ui_auth_legacy(self):
        self.local_user_fallback_ui_auth_test_body()

    def local_user_fallback_ui_auth_test_body(self):
        """rejected login should fall back to local db"""
        self.register_user("localuser", "localpass")

        # have the auth provider deny the request
        mock_password_provider.check_password.return_value = defer.succeed(False)

        # log in twice, to get two devices
        tok1 = self.login("localuser", "localpass")
        self.login("localuser", "localpass", device_id="dev2")
        mock_password_provider.check_password.reset_mock()

        # first delete should give a 401
        session = self._start_delete_device_session(tok1, "dev2")
        mock_password_provider.check_password.assert_not_called()

        # Wrong password
        channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
        self.assertEqual(channel.code, 401)  # XXX why not a 403?
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        mock_password_provider.check_password.assert_called_once_with(
            "@localuser:test", "xxx"
        )
        mock_password_provider.reset_mock()

        # Right password
        channel = self._authed_delete_device(
            tok1, "dev2", session, "localuser", "localpass"
        )
        self.assertEqual(channel.code, 200)
        mock_password_provider.check_password.assert_called_once_with(
            "@localuser:test", "localpass"
        )

    @override_config(
        {
            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
            "password_config": {"localdb_enabled": False},
        }
    )
    def test_no_local_user_fallback_login_legacy(self):
        self.no_local_user_fallback_login_test_body()

    def no_local_user_fallback_login_test_body(self):
        """localdb_enabled can block login with the local password"""
        self.register_user("localuser", "localpass")

        # check_password must return an awaitable
        mock_password_provider.check_password.return_value = defer.succeed(False)
        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 403)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        mock_password_provider.check_password.assert_called_once_with(
            "@localuser:test", "localpass"
        )

    @override_config(
        {
            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
            "password_config": {"localdb_enabled": False},
        }
    )
    def test_no_local_user_fallback_ui_auth_legacy(self):
        self.no_local_user_fallback_ui_auth_test_body()

    def no_local_user_fallback_ui_auth_test_body(self):
        """localdb_enabled can block ui auth with the local password"""
        self.register_user("localuser", "localpass")

        # allow login via the auth provider
        mock_password_provider.check_password.return_value = defer.succeed(True)

        # log in twice, to get two devices
        tok1 = self.login("localuser", "p")
        self.login("localuser", "p", device_id="dev2")
        mock_password_provider.check_password.reset_mock()

        # first delete should give a 401
        channel = self._delete_device(tok1, "dev2")
        self.assertEqual(channel.code, 401)
        # m.login.password UIA is permitted because the auth provider allows it,
        # even though the localdb does not.
        self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
        session = channel.json_body["session"]
        mock_password_provider.check_password.assert_not_called()

        # now try deleting with the local password
        mock_password_provider.check_password.return_value = defer.succeed(False)
        channel = self._authed_delete_device(
            tok1, "dev2", session, "localuser", "localpass"
        )
        self.assertEqual(channel.code, 401)  # XXX why not a 403?
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        mock_password_provider.check_password.assert_called_once_with(
            "@localuser:test", "localpass"
        )

    @override_config(
        {
            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_password_auth_disabled_legacy(self):
        self.password_auth_disabled_test_body()

    def password_auth_disabled_test_body(self):
        """password auth doesn't work if it's disabled across the board"""
        # login flows should be empty
        flows = self._get_login_flows()
        self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)

        # login shouldn't work and should be rejected with a 400 ("unknown login type")
        channel = self._send_password_login("u", "p")
        self.assertEqual(channel.code, 400, channel.result)
        mock_password_provider.check_password.assert_not_called()

    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
    def test_custom_auth_provider_login_legacy(self):
        self.custom_auth_provider_login_test_body()

    @override_config(providers_config(CustomAuthProvider))
    def test_custom_auth_provider_login(self):
        self.custom_auth_provider_login_test_body()

    def custom_auth_provider_login_test_body(self):
        # login flows should have the custom flow and m.login.password, since we
        # haven't disabled local password lookup.
        # (password must come first, because reasons)
        flows = self._get_login_flows()
        self.assertEqual(
            flows,
            [{"type": "m.login.password"}, {"type": "test.login_type"}]
            + ADDITIONAL_LOGIN_FLOWS,
        )

        # login with missing param should be rejected
        channel = self._send_login("test.login_type", "u")
        self.assertEqual(channel.code, 400, channel.result)
        mock_password_provider.check_auth.assert_not_called()

        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@user:bz", None)
        )
        channel = self._send_login("test.login_type", "u", test_field="y")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@user:bz", channel.json_body["user_id"])
        mock_password_provider.check_auth.assert_called_once_with(
            "u", "test.login_type", {"test_field": "y"}
        )
        mock_password_provider.reset_mock()

        # try a weird username. Again, it's unclear what we *expect* to happen
        # in these cases, but at least we can guard against the API changing
        # unexpectedly
        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@ MALFORMED! :bz", None)
        )
        channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
        mock_password_provider.check_auth.assert_called_once_with(
            " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
        )

    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
    def test_custom_auth_provider_ui_auth_legacy(self):
        self.custom_auth_provider_ui_auth_test_body()

    @override_config(providers_config(CustomAuthProvider))
    def test_custom_auth_provider_ui_auth(self):
        self.custom_auth_provider_ui_auth_test_body()

    def custom_auth_provider_ui_auth_test_body(self):
        # register the user and log in twice, to get two devices
        self.register_user("localuser", "localpass")
        tok1 = self.login("localuser", "localpass")
        self.login("localuser", "localpass", device_id="dev2")

        # make the initial request which returns a 401
        channel = self._delete_device(tok1, "dev2")
        self.assertEqual(channel.code, 401)
        # Ensure that flows are what is expected.
        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
        self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
        session = channel.json_body["session"]

        # missing param
        body = {
            "auth": {
                "type": "test.login_type",
                "identifier": {"type": "m.id.user", "user": "localuser"},
                "session": session,
            },
        }

        channel = self._delete_device(tok1, "dev2", body)
        self.assertEqual(channel.code, 400)
        # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
        # use it...
        self.assertIn("Missing parameters", channel.json_body["error"])
        mock_password_provider.check_auth.assert_not_called()
        mock_password_provider.reset_mock()

        # right params, but authing as the wrong user
        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@user:bz", None)
        )
        body["auth"]["test_field"] = "foo"
        channel = self._delete_device(tok1, "dev2", body)
        self.assertEqual(channel.code, 403)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        mock_password_provider.check_auth.assert_called_once_with(
            "localuser", "test.login_type", {"test_field": "foo"}
        )
        mock_password_provider.reset_mock()

        # and finally, succeed
        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@localuser:test", None)
        )
        channel = self._delete_device(tok1, "dev2", body)
        self.assertEqual(channel.code, 200)
        mock_password_provider.check_auth.assert_called_once_with(
            "localuser", "test.login_type", {"test_field": "foo"}
        )

    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
    def test_custom_auth_provider_callback_legacy(self):
        self.custom_auth_provider_callback_test_body()

    @override_config(providers_config(CustomAuthProvider))
    def test_custom_auth_provider_callback(self):
        self.custom_auth_provider_callback_test_body()

    def custom_auth_provider_callback_test_body(self):
        callback = Mock(return_value=defer.succeed(None))

        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@user:bz", callback)
        )
        channel = self._send_login("test.login_type", "u", test_field="y")
        self.assertEqual(channel.code, 200, channel.result)
        self.assertEqual("@user:bz", channel.json_body["user_id"])
        mock_password_provider.check_auth.assert_called_once_with(
            "u", "test.login_type", {"test_field": "y"}
        )

        # check the args to the callback
        callback.assert_called_once()
        call_args, call_kwargs = callback.call_args
        # should be one positional arg
        self.assertEqual(len(call_args), 1)
        self.assertEqual(call_args[0]["user_id"], "@user:bz")
        for p in ["user_id", "access_token", "device_id", "home_server"]:
            self.assertIn(p, call_args[0])

    @override_config(
        {
            **legacy_providers_config(LegacyCustomAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_custom_auth_password_disabled_legacy(self):
        self.custom_auth_password_disabled_test_body()

    @override_config(
        {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
    )
    def test_custom_auth_password_disabled(self):
        self.custom_auth_password_disabled_test_body()

    def custom_auth_password_disabled_test_body(self):
        """Test login with a custom auth provider where password login is disabled"""
        self.register_user("localuser", "localpass")

        flows = self._get_login_flows()
        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)

        # login shouldn't work and should be rejected with a 400 ("unknown login type")
        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 400, channel.result)
        mock_password_provider.check_auth.assert_not_called()

    @override_config(
        {
            **legacy_providers_config(LegacyCustomAuthProvider),
            "password_config": {"enabled": False, "localdb_enabled": False},
        }
    )
    def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
        self.custom_auth_password_disabled_localdb_enabled_test_body()

    @override_config(
        {
            **providers_config(CustomAuthProvider),
            "password_config": {"enabled": False, "localdb_enabled": False},
        }
    )
    def test_custom_auth_password_disabled_localdb_enabled(self):
        self.custom_auth_password_disabled_localdb_enabled_test_body()

    def custom_auth_password_disabled_localdb_enabled_test_body(self):
        """Check the localdb_enabled == enabled == False

        Regression test for https://github.com/matrix-org/synapse/issues/8914: check
        that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
        cause an exception.
        """
        self.register_user("localuser", "localpass")

        flows = self._get_login_flows()
        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)

        # login shouldn't work and should be rejected with a 400 ("unknown login type")
        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 400, channel.result)
        mock_password_provider.check_auth.assert_not_called()

    @override_config(
        {
            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_password_custom_auth_password_disabled_login_legacy(self):
        self.password_custom_auth_password_disabled_login_test_body()

    @override_config(
        {
            **providers_config(PasswordCustomAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_password_custom_auth_password_disabled_login(self):
        self.password_custom_auth_password_disabled_login_test_body()

    def password_custom_auth_password_disabled_login_test_body(self):
        """log in with a custom auth provider which implements password, but password
        login is disabled"""
        self.register_user("localuser", "localpass")

        flows = self._get_login_flows()
        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)

        # login shouldn't work and should be rejected with a 400 ("unknown login type")
        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 400, channel.result)
        mock_password_provider.check_auth.assert_not_called()
        mock_password_provider.check_password.assert_not_called()

    @override_config(
        {
            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
        self.password_custom_auth_password_disabled_ui_auth_test_body()

    @override_config(
        {
            **providers_config(PasswordCustomAuthProvider),
            "password_config": {"enabled": False},
        }
    )
    def test_password_custom_auth_password_disabled_ui_auth(self):
        self.password_custom_auth_password_disabled_ui_auth_test_body()

    def password_custom_auth_password_disabled_ui_auth_test_body(self):
        """UI Auth with a custom auth provider which implements password, but password
        login is disabled"""
        # register the user and log in twice via the test login type to get two devices,
        self.register_user("localuser", "localpass")
        mock_password_provider.check_auth.return_value = defer.succeed(
            ("@localuser:test", None)
        )
        channel = self._send_login("test.login_type", "localuser", test_field="")
        self.assertEqual(channel.code, 200, channel.result)
        tok1 = channel.json_body["access_token"]

        channel = self._send_login(
            "test.login_type", "localuser", test_field="", device_id="dev2"
        )
        self.assertEqual(channel.code, 200, channel.result)

        # make the initial request which returns a 401
        channel = self._delete_device(tok1, "dev2")
        self.assertEqual(channel.code, 401)
        # Ensure that flows are what is expected. In particular, "password" should *not*
        # be present.
        self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
        self.assertNotIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
        session = channel.json_body["session"]

        mock_password_provider.reset_mock()

        # check that auth with password is rejected
        body = {
            "auth": {
                "type": "m.login.password",
                "identifier": {"type": "m.id.user", "user": "localuser"},
                "password": "localpass",
                "session": session,
            },
        }

        channel = self._delete_device(tok1, "dev2", body)
        self.assertEqual(channel.code, 400)
        self.assertEqual(
            "Password login has been disabled.", channel.json_body["error"]
        )
        mock_password_provider.check_auth.assert_not_called()
        mock_password_provider.check_password.assert_not_called()
        mock_password_provider.reset_mock()

        # successful auth
        body["auth"]["type"] = "test.login_type"
        body["auth"]["test_field"] = "x"
        channel = self._delete_device(tok1, "dev2", body)
        self.assertEqual(channel.code, 200)
        mock_password_provider.check_auth.assert_called_once_with(
            "localuser", "test.login_type", {"test_field": "x"}
        )
        mock_password_provider.check_password.assert_not_called()

    @override_config(
        {
            **legacy_providers_config(LegacyCustomAuthProvider),
            "password_config": {"localdb_enabled": False},
        }
    )
    def test_custom_auth_no_local_user_fallback_legacy(self):
        self.custom_auth_no_local_user_fallback_test_body()

    @override_config(
        {
            **providers_config(CustomAuthProvider),
            "password_config": {"localdb_enabled": False},
        }
    )
    def test_custom_auth_no_local_user_fallback(self):
        self.custom_auth_no_local_user_fallback_test_body()

    def custom_auth_no_local_user_fallback_test_body(self):
        """Test login with a custom auth provider where the local db is disabled"""
        self.register_user("localuser", "localpass")

        flows = self._get_login_flows()
        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)

        # password login shouldn't work and should be rejected with a 400
        # ("unknown login type")
        channel = self._send_password_login("localuser", "localpass")
        self.assertEqual(channel.code, 400, channel.result)

    def test_on_logged_out(self):
        """Tests that the on_logged_out callback is called when the user logs out."""
        self.register_user("rin", "password")
        tok = self.login("rin", "password")

        self.called = False

        async def on_logged_out(user_id, device_id, access_token):
            self.called = True

        # Wrap the coroutine function in a Mock so that we can also assert on the
        # number of calls, not just on the side effect.
        on_logged_out = Mock(side_effect=on_logged_out)
        self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
            on_logged_out
        )

        channel = self.make_request(
            "POST",
            "/_matrix/client/v3/logout",
            {},
            access_token=tok,
        )
        self.assertEqual(channel.code, 200)
        on_logged_out.assert_called_once()
        self.assertTrue(self.called)

    def test_username(self):
        """Tests that the get_username_for_registration callback can define the username
        of a user when registering.
        """
        self._setup_get_username_for_registration()

        username = "rin"
        channel = self.make_request(
            "POST",
            "/register",
            {
                "username": username,
                "password": "bar",
                "auth": {"type": LoginType.DUMMY},
            },
        )
        self.assertEqual(channel.code, 200)

        # Our callback takes the username and appends "-foo" to it, check that's what we
        # have.
        mxid = channel.json_body["user_id"]
        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")

    def test_username_uia(self):
        """Tests that the get_username_for_registration callback is only called at the
        end of the UIA flow.
        """
        m = self._setup_get_username_for_registration()

        # Initiate the UIA flow.
        username = "rin"
        channel = self.make_request(
            "POST",
            "register",
            {"username": username, "type": "m.login.password", "password": "bar"},
        )
        self.assertEqual(channel.code, 401)
        self.assertIn("session", channel.json_body)

        # Check that the callback hasn't been called yet.
        m.assert_not_called()

        # Finish the UIA flow.
        session = channel.json_body["session"]
        channel = self.make_request(
            "POST",
            "register",
            {"auth": {"session": session, "type": LoginType.DUMMY}},
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        mxid = channel.json_body["user_id"]
        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")

        # Check that the callback has been called.
        m.assert_called_once()

    def _setup_get_username_for_registration(self) -> Mock:
        """Registers a get_username_for_registration callback that appends "-foo" to the
        username the client is trying to register.

        Returns:
            The Mock wrapping the callback, so callers can assert on its invocations.
        """

        async def get_username_for_registration(uia_results, params):
            self.assertIn(LoginType.DUMMY, uia_results)
            username = params["username"]
            return username + "-foo"

        m = Mock(side_effect=get_username_for_registration)

        password_auth_provider = self.hs.get_password_auth_provider()
        password_auth_provider.get_username_for_registration_callbacks.append(m)

        return m

    def _get_login_flows(self) -> List[JsonDict]:
        """Fetch the list of advertised login flows from GET /login (asserting a 200)."""
        channel = self.make_request("GET", "/_matrix/client/r0/login")
        self.assertEqual(channel.code, 200, channel.result)
        return channel.json_body["flows"]

    def _send_password_login(self, user: str, password: str) -> FakeChannel:
        """POST an m.login.password login request for `user` with `password`."""
        return self._send_login(type="m.login.password", user=user, password=password)

    def _send_login(self, type, user, **params) -> FakeChannel:
        """POST a login request of the given `type` for the m.id.user identifier
        `user`. Any extra keyword arguments are sent as top-level body fields."""
        params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
        channel = self.make_request("POST", "/_matrix/client/r0/login", params)
        return channel

    def _start_delete_device_session(self, access_token, device_id) -> str:
        """Make an initial delete device request, and return the UI Auth session ID"""
        channel = self._delete_device(access_token, device_id)
        self.assertEqual(channel.code, 401)
        # Ensure that flows are what is expected.
        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
        return channel.json_body["session"]

    def _authed_delete_device(
        self,
        access_token: str,
        device_id: str,
        session: str,
        user_id: str,
        password: str,
    ) -> FakeChannel:
        """Make a delete device request, authenticating with the given uid/password"""
        return self._delete_device(
            access_token,
            device_id,
            {
                "auth": {
                    "type": "m.login.password",
                    "identifier": {"type": "m.id.user", "user": user_id},
                    "password": password,
                    "session": session,
                },
            },
        )

    def _delete_device(
        self,
        access_token: str,
        device: str,
        body: Union[JsonDict, bytes] = b"",
    ) -> FakeChannel:
        """Delete an individual device, optionally sending a (UI auth) request body."""
        channel = self.make_request(
            "DELETE", "devices/" + device, body, access_token=access_token
        )
        return channel