Add type hints to tests/rest/client (#12066)

This commit is contained in:
Dirk Klimpel 2022-02-23 14:33:19 +01:00 committed by GitHub
parent 5b2b36809f
commit 64c73c6ac8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 119 deletions

1
changelog.d/12066.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View File

@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_oidc import HAS_OIDC
@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless
class DummyRecaptchaChecker(UserInteractiveAuthChecker): class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs): def __init__(self, hs: HomeServer) -> None:
super().__init__(hs) super().__init__(hs)
self.recaptcha_attempts = [] self.recaptcha_attempts: List[Tuple[dict, str]] = []
def check_auth(self, authdict, clientip): def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip)) self.recaptcha_attempts.append((authdict, clientip))
return succeed(True) return succeed(True)
@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
] ]
hijack_auth = False hijack_auth = False
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config) hs = self.setup_test_homeserver(config=config)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.recaptcha_checker = DummyRecaptchaChecker(hs) self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler() auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(attempts), 1) self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a") self.assertEqual(attempts[0][0]["response"], "a")
def test_fallback_captcha(self): def test_fallback_captcha(self) -> None:
"""Ensure that fallback auth via a captcha works.""" """Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec # Returns a 401 as per the spec
channel = self.register( channel = self.register(
@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
# We're given a registered user. # We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test") self.assertEqual(channel.json_body["user_id"], "@user:test")
def test_complete_operation_unknown_session(self): def test_complete_operation_unknown_session(self) -> None:
""" """
Attempting to mark an invalid session as complete should error. Attempting to mark an invalid session as complete should error.
""" """
@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets, register.register_servlets,
] ]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
return config return config
def create_resource_dict(self): def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict() resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs)) resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict return resource_dict
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass" self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass) self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1" self.device_id = "dev1"
@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
return channel return channel
def test_ui_auth(self): def test_ui_auth(self) -> None:
""" """
Test user interactive authentication outside of registration. Test user interactive authentication outside of registration.
""" """
@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
}, },
) )
def test_grandfathered_identifier(self): def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict """Check behaviour without "identifier" dict
Synapse used to require clients to submit a "user" field for m.login.password Synapse used to require clients to submit a "user" field for m.login.password
@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
}, },
) )
def test_can_change_body(self): def test_can_change_body(self) -> None:
""" """
The client dict can be modified during the user interactive authentication session. The client dict can be modified during the user interactive authentication session.
@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
}, },
) )
def test_cannot_change_uri(self): def test_cannot_change_uri(self) -> None:
""" """
The initial requested URI cannot be modified during the user interactive authentication session. The initial requested URI cannot be modified during the user interactive authentication session.
""" """
@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
) )
@unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
def test_can_reuse_session(self): def test_can_reuse_session(self) -> None:
""" """
The session can be reused if configured. The session can be reused if configured.
@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG}) @override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_via_sso(self): def test_ui_auth_via_sso(self) -> None:
"""Test a successful UI Auth flow via SSO """Test a successful UI Auth flow via SSO
This includes: This includes:
@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG}) @override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self): def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username") login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"] user_tok = login_resp["access_token"]
device_id = login_resp["device_id"] device_id = login_resp["device_id"]
@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
def test_does_not_offer_sso_for_password_user(self): def test_does_not_offer_sso_for_password_user(self) -> None:
channel = self.delete_device( channel = self.delete_device(
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
) )
@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG}) @override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self): def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows""" """A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user) self.assertEqual(login_resp["user_id"], self.user)
@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG}) @override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self): def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error""" """If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in # log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
] ]
hijack_auth = False hijack_auth = False
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass" self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass) self.user = self.register_user("test", self.user_pass)
@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
{"refresh_token": refresh_token}, {"refresh_token": refresh_token},
) )
def is_access_token_valid(self, access_token) -> bool: def is_access_token_valid(self, access_token: str) -> bool:
""" """
Checks whether an access token is valid, returning whether it is or not. Checks whether an access token is valid, returning whether it is or not.
""" """
@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
return code == HTTPStatus.OK return code == HTTPStatus.OK
def test_login_issue_refresh_token(self): def test_login_issue_refresh_token(self) -> None:
""" """
A login response should include a refresh_token only if asked. A login response should include a refresh_token only if asked.
""" """
@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body)
def test_register_issue_refresh_token(self): def test_register_issue_refresh_token(self) -> None:
""" """
A register response should include a refresh_token only if asked. A register response should include a refresh_token only if asked.
""" """
@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body)
def test_token_refresh(self): def test_token_refresh(self) -> None:
""" """
A refresh token can be used to issue a new access token. A refresh token can be used to issue a new access token.
""" """
@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
) )
@override_config({"refreshable_access_token_lifetime": "1m"}) @override_config({"refreshable_access_token_lifetime": "1m"})
def test_refreshable_access_token_expiration(self): def test_refreshable_access_token_expiration(self) -> None:
""" """
The access token should have some time as specified in the config. The access token should have some time as specified in the config.
""" """
@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"nonrefreshable_access_token_lifetime": "10m", "nonrefreshable_access_token_lifetime": "10m",
} }
) )
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
self,
) -> None:
""" """
Tests that the expiry times for refreshable and non-refreshable access Tests that the expiry times for refreshable and non-refreshable access
tokens can be different. tokens can be different.
@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
@override_config( @override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
) )
def test_refresh_token_expiry(self): def test_refresh_token_expiry(self) -> None:
""" """
The refresh token can be configured to have a limited lifetime. The refresh token can be configured to have a limited lifetime.
When that lifetime has ended, the refresh token can no longer be used to When that lifetime has ended, the refresh token can no longer be used to
@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"session_lifetime": "3m", "session_lifetime": "3m",
} }
) )
def test_ultimate_session_expiry(self): def test_ultimate_session_expiry(self) -> None:
""" """
The session can be configured to have an ultimate, limited lifetime. The session can be configured to have an ultimate, limited lifetime.
""" """
@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
) )
def test_refresh_token_invalidation(self): def test_refresh_token_invalidation(self) -> None:
"""Refresh tokens are invalidated after first use of the next token. """Refresh tokens are invalidated after first use of the next token.
A refresh token is considered invalid if: A refresh token is considered invalid if:
@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
) )
def test_many_token_refresh(self): def test_many_token_refresh(self) -> None:
""" """
If a refresh is performed many times during a session, there shouldn't be If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time. extra 'cruft' built up over time.

View File

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.rest.client import capabilities, login from synapse.rest.client import capabilities, login
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = b"/capabilities" self.url = b"/capabilities"
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.config = hs.config self.config = hs.config
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.localpart = "user" self.localpart = "user"
self.password = "pass" self.password = "pass"
self.user = self.register_user(self.localpart, self.password) self.user = self.register_user(self.localpart, self.password)
def test_check_auth_required(self): def test_check_auth_required(self) -> None:
channel = self.make_request("GET", self.url) channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self): def test_get_room_version_capabilities(self) -> None:
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token) channel = self.make_request("GET", self.url, access_token=access_token)
@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"], capabilities["m.room_versions"]["default"],
) )
def test_get_change_password_capabilities_password_login(self): def test_get_change_password_capabilities_password_login(self) -> None:
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token) channel = self.make_request("GET", self.url, access_token=access_token)
@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.change_password"]["enabled"]) self.assertTrue(capabilities["m.change_password"]["enabled"])
@override_config({"password_config": {"localdb_enabled": False}}) @override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self): def test_get_change_password_capabilities_localdb_disabled(self) -> None:
access_token = self.get_success( access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id( self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None self.user, device_id=None, valid_until_ms=None
@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertFalse(capabilities["m.change_password"]["enabled"]) self.assertFalse(capabilities["m.change_password"]["enabled"])
@override_config({"password_config": {"enabled": False}}) @override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self): def test_get_change_password_capabilities_password_disabled(self) -> None:
access_token = self.get_success( access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id( self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None self.user, device_id=None, valid_until_ms=None
@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"]) self.assertFalse(capabilities["m.change_password"]["enabled"])
def test_get_change_users_attributes_capabilities(self): def test_get_change_users_attributes_capabilities(self) -> None:
"""Test that server returns capabilities by default.""" """Test that server returns capabilities by default."""
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
@override_config({"enable_set_displayname": False}) @override_config({"enable_set_displayname": False})
def test_get_set_displayname_capabilities_displayname_disabled(self): def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
"""Test if set displayname is disabled that the server responds it.""" """Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertFalse(capabilities["m.set_displayname"]["enabled"]) self.assertFalse(capabilities["m.set_displayname"]["enabled"])
@override_config({"enable_set_avatar_url": False}) @override_config({"enable_set_avatar_url": False})
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
"""Test if set avatar_url is disabled that the server responds it.""" """Test if set avatar_url is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
@override_config({"enable_3pid_changes": False}) @override_config({"enable_3pid_changes": False})
def test_get_change_3pid_capabilities_3pid_disabled(self): def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
"""Test if change 3pid is disabled that the server responds it.""" """Test if change 3pid is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password) access_token = self.login(self.localpart, self.password)
@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
@override_config({"experimental_features": {"msc3244_enabled": False}}) @override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self): def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None:
access_token = self.get_success( access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id( self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None self.user, device_id=None, valid_until_ms=None
@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
) )
def test_get_does_include_msc3244_fields_when_enabled(self): def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
access_token = self.get_success( access_token = self.get_success(
self.auth_handler.create_access_token_for_user_id( self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None self.user, device_id=None, valid_until_ms=None

View File

@ -20,6 +20,7 @@ from urllib.parse import urlencode
import pymacaroons import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import devices, login, logout, register from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.types import create_requester from synapse.types import create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2 from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless from tests.unittest import HomeserverTestCase, override_config, skip_unless
@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True self.hs.config.registration.enable_registration = True
self.hs.config.registration.registrations_require_3pid = [] self.hs.config.registration.registrations_require_3pid = []
@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
} }
) )
def test_POST_ratelimiting_per_address(self): def test_POST_ratelimiting_per_address(self) -> None:
# Create different users so we're sure not to be bothered by the per-user # Create different users so we're sure not to be bothered by the per-user
# ratelimiter. # ratelimiter.
for i in range(0, 6): for i in range(0, 6):
@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
} }
) )
def test_POST_ratelimiting_per_account(self): def test_POST_ratelimiting_per_account(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):
@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
} }
) )
def test_POST_ratelimiting_per_account_failed_attempts(self): def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):
@ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_soft_logout(self): def test_soft_logout(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token # we shouldn't be able to make requests without an access token
@ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False) self.assertEquals(channel.json_body["soft_logout"], False)
def _delete_device(self, access_token, user_id, password, device_id): def _delete_device(
self, access_token: str, user_id: str, password: str, device_id: str
) -> None:
"""Perform the UI-Auth to delete a device""" """Perform the UI-Auth to delete a device"""
channel = self.make_request( channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token b"DELETE", "devices/" + device_id, access_token=access_token
@ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.code, 200, channel.result) self.assertEquals(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self): def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
# log in as normal # log in as normal
@ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
self,
) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
# log in as normal # log in as normal
@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs)) d.update(build_synapse_client_resource_tree(self.hs))
return d return d
def test_get_login_flows(self): def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows""" """GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_multi_sso_redirect(self): def test_multi_sso_redirect(self) -> None:
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None) channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML # hitting that picker should give us some HTML
channel = self.make_request("GET", uri) channel = self.make_request("GET", uri)
@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self): def test_multi_sso_redirect_to_cas(self) -> None:
"""If CAS is chosen, should redirect to the CAS server""" """If CAS is chosen, should redirect to the CAS server"""
channel = self.make_request( channel = self.make_request(
@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
service_uri_params = urllib.parse.parse_qs(service_uri_query) service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_saml(self): def test_multi_sso_redirect_to_saml(self) -> None:
"""If SAML is chosen, should redirect to the SAML server""" """If SAML is chosen, should redirect to the SAML server"""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
relay_state_param = saml_uri_params["RelayState"][0] relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
def test_login_via_oidc(self): def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint""" """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
# pick the default OIDC provider # pick the default OIDC provider
@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test") self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self): def test_multi_sso_redirect_to_unknown(self) -> None:
"""An unknown IdP should cause a 400""" """An unknown IdP should cause a 400"""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self): def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx") channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self): def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc") channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server # it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
"""Send a request to /_matrix/client/r0/login/sso/redirect """Send a request to /_matrix/client/r0/login/sso/redirect
... possibly specifying an IDP provider ... possibly specifying an IDP provider
@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.base_url = "https://matrix.goodserver.com/" self.base_url = "https://matrix.goodserver.com/"
self.redirect_path = "_synapse/client/login/sso/redirect/confirm" self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase):
cas_user_id = "username" cas_user_id = "username"
self.user_id = "@%s:test" % cas_user_id self.user_id = "@%s:test" % cas_user_id
async def get_raw(uri, args): async def get_raw(uri: str, args: Any) -> bytes:
"""Return an example response payload from a call to the `/proxyValidate` """Return an example response payload from a call to the `/proxyValidate`
endpoint of a CAS server, copied from endpoint of a CAS server, copied from
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.deactivate_account_handler = hs.get_deactivate_account_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler()
def test_cas_redirect_confirm(self): def test_cas_redirect_confirm(self) -> None:
"""Tests that the SSO login flow serves a confirmation page before redirecting a """Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL. user to the redirect URL.
""" """
@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase):
} }
} }
) )
def test_cas_redirect_whitelisted(self): def test_cas_redirect_whitelisted(self) -> None:
"""Tests that the SSO login flow serves a redirect to a whitelisted url""" """Tests that the SSO login flow serves a redirect to a whitelisted url"""
self._test_redirect("https://legit-site.com/") self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"}) @override_config({"public_baseurl": "https://example.com"})
def test_cas_redirect_login_fallback(self): def test_cas_redirect_login_fallback(self) -> None:
self._test_redirect("https://example.com/_matrix/static/client/login") self._test_redirect("https://example.com/_matrix/static/client/login")
def _test_redirect(self, redirect_url): def _test_redirect(self, redirect_url: str) -> None:
"""Tests that the SSO login flow serves a redirect for the given redirect URL.""" """Tests that the SSO login flow serves a redirect for the given redirect URL."""
cas_ticket_url = ( cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
def test_deactivated_user(self): def test_deactivated_user(self) -> None:
"""Logging in as a deactivated account should error.""" """Logging in as a deactivated account should error."""
redirect_url = "https://legit-site.com/" redirect_url = "https://legit-site.com/"
@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
"algorithm": jwt_algorithm, "algorithm": jwt_algorithm,
} }
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# If jwt_config has been defined (eg via @override_config), don't replace it. # If jwt_config has been defined (eg via @override_config), don't replace it.
@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
def test_login_jwt_valid_registered(self): def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self): def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"}) channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self): def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret") channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
"JWT validation failed: Signature verification failed", "JWT validation failed: Signature verification failed",
) )
def test_login_jwt_expired(self): def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000}) channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Signature has expired" channel.json_body["error"], "JWT validation failed: Signature has expired"
) )
def test_login_jwt_not_before(self): def test_login_jwt_not_before(self) -> None:
now = int(time.time()) now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
"JWT validation failed: The token is not yet valid (nbf)", "JWT validation failed: The token is not yet valid (nbf)",
) )
def test_login_no_sub(self): def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"}) channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT") self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self): def test_login_iss(self) -> None:
"""Test validating the issuer claim.""" """Test validating the issuer claim."""
# A valid issuer. # A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
'JWT validation failed: Token is missing the "iss" claim', 'JWT validation failed: Token is missing the "iss" claim',
) )
def test_login_iss_no_config(self): def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration.""" """Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self): def test_login_aud(self) -> None:
"""Test validating the audience claim.""" """Test validating the audience claim."""
# A valid audience. # A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
'JWT validation failed: Token is missing the "aud" claim', 'JWT validation failed: Token is missing the "aud" claim',
) )
def test_login_aud_no_config(self): def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration.""" """Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Invalid audience" channel.json_body["error"], "JWT validation failed: Invalid audience"
) )
def test_login_default_sub(self): def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim.""" """Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self): def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim.""" """Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"}) channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self): def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"} params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
] ]
) )
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["jwt_config"] = { config["jwt_config"] = {
"enabled": True, "enabled": True,
@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
def test_login_jwt_valid(self): def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self): def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
register.register_servlets, register.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.service = ApplicationService( self.service = ApplicationService(
@ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.services_cache.append(self.another_service) self.hs.get_datastores().main.services_cache.append(self.another_service)
return self.hs return self.hs
def test_login_appservice_user(self): def test_login_appservice_user(self) -> None:
"""Test that an appservice user can use /login""" """Test that an appservice user can use /login"""
self.register_appservice_user(AS_USER, self.service.token) self.register_appservice_user(AS_USER, self.service.token)
@ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self): def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login""" """Test that the appservice bot can use /login"""
self.register_appservice_user(AS_USER, self.service.token) self.register_appservice_user(AS_USER, self.service.token)
@ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self): def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token""" """Test that non-as users cannot login with the as token"""
self.register_appservice_user(AS_USER, self.service.token) self.register_appservice_user(AS_USER, self.service.token)
@ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self): def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token""" """Test that as users cannot login with wrong as token"""
self.register_appservice_user(AS_USER, self.service.token) self.register_appservice_user(AS_USER, self.service.token)
@ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_no_token(self): def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice """Test that users must provide a token when using the appservice
login method login method
""" """
@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
servlets = [login.register_servlets] servlets = [login.register_servlets]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs)) d.update(build_synapse_client_resource_tree(self.hs))
return d return d
def test_username_picker(self): def test_username_picker(self) -> None:
"""Test the happy path of a username picker flow.""" """Test the happy path of a username picker flow."""
# do the start of the login flow # do the start of the login flow

View File

@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
from typing import List, Optional
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ( from synapse.api.constants import (
EventContentFields, EventContentFields,
@ -24,6 +27,9 @@ from synapse.api.constants import (
RelationTypes, RelationTypes,
) )
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.federation.transport.test_knocking import ( from tests.federation.transport.test_knocking import (
@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def test_sync_argless(self): def test_sync_argless(self) -> None:
channel = self.make_request("GET", "/sync") channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def test_sync_filter_labels(self): def test_sync_filter_labels(self) -> None:
"""Test that we can filter by a label.""" """Test that we can filter by a label."""
sync_filter = json.dumps( sync_filter = json.dumps(
{ {
@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
def test_sync_filter_not_labels(self): def test_sync_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label.""" """Test that we can filter by the absence of a label."""
sync_filter = json.dumps( sync_filter = json.dumps(
{ {
@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
events[2]["content"]["body"], "with two wrong labels", events[2] events[2]["content"]["body"], "with two wrong labels", events[2]
) )
def test_sync_filter_labels_not_labels(self): def test_sync_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label.""" """Test that we can filter by both a label and the absence of another label."""
sync_filter = json.dumps( sync_filter = json.dumps(
{ {
@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(len(events), 1, [event["content"] for event in events])
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
def _test_sync_filter_labels(self, sync_filter): def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]:
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
user_id = True user_id = True
hijack_auth = False hijack_auth = False
def test_sync_backwards_typing(self): def test_sync_backwards_typing(self) -> None:
""" """
If the typing serial goes backwards and the typing handler is then reset If the typing serial goes backwards and the typing handler is then reset
(such as when the master restarts and sets the typing serial to 0), we (such as when the master restarts and sets the typing serial to 0), we
@ -298,7 +304,7 @@ class SyncKnockTestCase(
knock.register_servlets, knock.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.url = "/sync?since=%s" self.url = "/sync?since=%s"
self.next_batch = "s0" self.next_batch = "s0"
@ -336,7 +342,7 @@ class SyncKnockTestCase(
) )
@override_config({"experimental_features": {"msc2403_enabled": True}}) @override_config({"experimental_features": {"msc2403_enabled": True}})
def test_knock_room_state(self): def test_knock_room_state(self) -> None:
"""Tests that /sync returns state from a room after knocking on it.""" """Tests that /sync returns state from a room after knocking on it."""
# Knock on a room # Knock on a room
channel = self.make_request( channel = self.make_request(
@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s" self.url = "/sync?since=%s"
self.next_batch = "s0" self.next_batch = "s0"
@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@override_config({"experimental_features": {"msc2285_enabled": True}}) @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_hidden_read_receipts(self): def test_hidden_read_receipts(self) -> None:
# Send a message as the first user # Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok) res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
] ]
) )
def test_read_receipt_with_empty_body( def test_read_receipt_with_empty_body(
self, name, user_agent: str, expected_status_code: int self, name: str, user_agent: str, expected_status_code: int
): ) -> None:
# Send a message as the first user # Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok) res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, expected_status_code) self.assertEqual(channel.code, expected_status_code)
def _get_read_receipt(self): def _get_read_receipt(self) -> Optional[JsonDict]:
"""Syncs and returns the read receipt.""" """Syncs and returns the read receipt."""
# Checks if event is a read receipt # Checks if event is a read receipt
def is_read_receipt(event): def is_read_receipt(event: JsonDict) -> bool:
return event["type"] == "m.receipt" return event["type"] == "m.receipt"
# Sync # Sync
@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
"ephemeral" "ephemeral"
]["events"] ]["events"]
return next(filter(is_read_receipt, ephemeral_events), None) receipt_event = filter(is_read_receipt, ephemeral_events)
return next(receipt_event, None)
class UnreadMessagesTestCase(unittest.HomeserverTestCase): class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
receipts.register_servlets, receipts.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s" self.url = "/sync?since=%s"
self.next_batch = "s0" self.next_batch = "s0"
@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
tok=self.tok, tok=self.tok,
) )
def test_unread_counts(self): def test_unread_counts(self) -> None:
"""Tests that /sync returns the right value for the unread count (MSC2654).""" """Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count. # Check that our own messages don't increase the unread count.
@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
) )
self._check_unread_count(5) self._check_unread_count(5)
def _check_unread_count(self, expected_count: int): def _check_unread_count(self, expected_count: int) -> None:
"""Syncs and compares the unread count with the expected value.""" """Syncs and compares the unread count with the expected value."""
channel = self.make_request( channel = self.make_request(
@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def test_noop_sync_does_not_tightloop(self): def test_noop_sync_does_not_tightloop(self) -> None:
"""If the sync times out, we shouldn't cache the result """If the sync times out, we shouldn't cache the result
Essentially a regression test for #8518. Essentially a regression test for #8518.
@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
devices.register_servlets, devices.register_servlets,
] ]
def test_user_with_no_rooms_receives_self_device_list_updates(self): def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates""" """Tests that a user with no rooms still receives their own device list updates"""
device_id = "TESTDEVICE" device_id = "TESTDEVICE"