Add missing type hints to tests.handlers. (#14680)

And do not allow untyped defs in tests.handlers.
This commit is contained in:
Patrick Cloke 2022-12-16 06:53:01 -05:00 committed by GitHub
parent 54c012c5a8
commit 652d1669c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 527 additions and 378 deletions

View file

@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
from typing import Any, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass
def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass
def check_password(self, *args):
def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass
def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass
def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}
def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@ -75,15 +76,15 @@ class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass
def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)
def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type."""
@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass
def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass
def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
as well as a password login"""
@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass
def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
}
)
def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
def check_pass(self, *args):
def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
def setUp(self):
def setUp(self) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
def test_password_only_auth_progiver_login_legacy(self) -> None:
self.password_only_auth_provider_login_test_body()
def password_only_auth_provider_login_test_body(self):
def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self):
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body()
def password_only_auth_provider_ui_auth_test_body(self):
def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login_legacy(self):
def test_local_user_fallback_login_legacy(self) -> None:
self.local_user_fallback_login_test_body()
def local_user_fallback_login_test_body(self):
def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth_legacy(self):
def test_local_user_fallback_ui_auth_legacy(self) -> None:
self.local_user_fallback_ui_auth_test_body()
def local_user_fallback_ui_auth_test_body(self):
def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login_legacy(self):
def test_no_local_user_fallback_login_legacy(self) -> None:
self.no_local_user_fallback_login_test_body()
def no_local_user_fallback_login_test_body(self):
def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth_legacy(self):
def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
self.no_local_user_fallback_ui_auth_test_body()
def no_local_user_fallback_ui_auth_test_body(self):
def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled_legacy(self):
def test_password_auth_disabled_legacy(self) -> None:
self.password_auth_disabled_test_body()
def password_auth_disabled_test_body(self):
def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self):
def test_custom_auth_provider_login_legacy(self) -> None:
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
def test_custom_auth_provider_login(self) -> None:
self.custom_auth_provider_login_test_body()
def custom_auth_provider_login_test_body(self):
def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self):
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
def test_custom_auth_provider_ui_auth(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
def custom_auth_provider_ui_auth_test_body(self):
def custom_auth_provider_ui_auth_test_body(self) -> None:
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self):
def test_custom_auth_provider_callback_legacy(self) -> None:
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
def test_custom_auth_provider_callback(self) -> None:
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable(
@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_custom_auth_password_disabled_legacy(self):
def test_custom_auth_password_disabled_legacy(self) -> None:
self.custom_auth_password_disabled_test_body()
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
def test_custom_auth_password_disabled(self) -> None:
self.custom_auth_password_disabled_test_body()
def custom_auth_password_disabled_test_body(self):
def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config(
@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
def custom_auth_password_disabled_localdb_enabled_test_body(self):
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login_legacy(self):
def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
@override_config(
@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
def test_password_custom_auth_password_disabled_login(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
def password_custom_auth_password_disabled_login_test_body(self):
def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
def password_custom_auth_password_disabled_ui_auth_test_body(self):
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback_legacy(self):
def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
@override_config(
@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback(self):
def test_custom_auth_no_local_user_fallback(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
def custom_auth_no_local_user_fallback_test_body(self):
def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self):
def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")
self.called = False
async def on_logged_out(user_id, device_id, access_token):
async def on_logged_out(
user_id: str, device_id: Optional[str], access_token: str
) -> None:
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once()
self.assertTrue(self.called)
def test_username(self):
def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
def test_username_uia(self):
def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self):
def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
def test_displayname(self):
def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(display_name, username + "-foo")
def test_displayname_uia(self):
def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
def _test_3pid_allowed(self, username: str, registration: bool):
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
client is trying to register.
"""
async def callback(uia_results, params):
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)