mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints to some tests/handlers files. (#12224)
This commit is contained in:
parent
2fcf4b3f6c
commit
5dd949bee6
1
changelog.d/12224.misc
Normal file
1
changelog.d/12224.misc
Normal file
@ -0,0 +1 @@
|
||||
Add type hints to tests files.
|
5
mypy.ini
5
mypy.ini
@ -67,13 +67,8 @@ exclude = (?x)
|
||||
|tests/federation/transport/test_knocking.py
|
||||
|tests/federation/transport/test_server.py
|
||||
|tests/handlers/test_cas.py
|
||||
|tests/handlers/test_directory.py
|
||||
|tests/handlers/test_e2e_keys.py
|
||||
|tests/handlers/test_federation.py
|
||||
|tests/handlers/test_oidc.py
|
||||
|tests/handlers/test_presence.py
|
||||
|tests/handlers/test_profile.py
|
||||
|tests/handlers/test_saml.py
|
||||
|tests/handlers/test_typing.py
|
||||
|tests/http/federation/test_matrix_federation_agent.py
|
||||
|tests/http/federation/test_srv_resolver.py
|
||||
|
@ -12,14 +12,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.rest.client import directory, login, room
|
||||
from synapse.types import RoomAlias, create_requester
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
|
||||
class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests the directory service."""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.mock_federation = Mock()
|
||||
self.mock_registry = Mock()
|
||||
|
||||
self.query_handlers = {}
|
||||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||
|
||||
def register_query_handler(query_type, handler):
|
||||
def register_query_handler(
|
||||
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
) -> None:
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
self.mock_registry.register_query_handler = register_query_handler
|
||||
@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_get_local_association(self):
|
||||
def test_get_local_association(self) -> None:
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.my_room, "!8765qwer:test", ["test"]
|
||||
@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||
|
||||
def test_get_remote_association(self):
|
||||
def test_get_remote_association(self) -> None:
|
||||
self.mock_federation.make_query.return_value = make_awaitable(
|
||||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
||||
)
|
||||
@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self):
|
||||
def test_incoming_fed_query(self) -> None:
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.your_room, "!8765asdf:test", ["test"]
|
||||
@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_directory_handler()
|
||||
|
||||
# Create user
|
||||
@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def test_create_alias_joined_room(self):
|
||||
def test_create_alias_joined_room(self) -> None:
|
||||
"""A user can create an alias for a room they're in."""
|
||||
self.get_success(
|
||||
self.handler.create_association(
|
||||
@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_create_alias_other_room(self):
|
||||
def test_create_alias_other_room(self) -> None:
|
||||
"""A user cannot create an alias for a room they're NOT in."""
|
||||
other_room_id = self.helper.create_room_as(
|
||||
self.admin_user, tok=self.admin_user_tok
|
||||
@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_create_alias_admin(self):
|
||||
def test_create_alias_admin(self) -> None:
|
||||
"""An admin can create an alias for a room they're NOT in."""
|
||||
other_room_id = self.helper.create_room_as(
|
||||
self.test_user, tok=self.test_user_tok
|
||||
@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def _create_alias(self, user):
|
||||
def _create_alias(self, user) -> None:
|
||||
# Create a new alias to this room.
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_delete_alias_not_allowed(self):
|
||||
def test_delete_alias_not_allowed(self) -> None:
|
||||
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
self.get_failure(
|
||||
@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.AuthError,
|
||||
)
|
||||
|
||||
def test_delete_alias_creator(self):
|
||||
def test_delete_alias_creator(self) -> None:
|
||||
"""An alias creator can delete their own alias."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_admin(self):
|
||||
def test_delete_alias_admin(self) -> None:
|
||||
"""A server admin can delete an alias created by another user."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_sufficient_power(self):
|
||||
def test_delete_alias_sufficient_power(self) -> None:
|
||||
"""A user with a sufficient power level should be able to delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
|
||||
@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
return room_alias
|
||||
|
||||
def _set_canonical_alias(self, content):
|
||||
def _set_canonical_alias(self, content) -> None:
|
||||
"""Configure the canonical alias state on the room."""
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_remove_alias(self):
|
||||
def test_remove_alias(self) -> None:
|
||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||
# Set this new alias as the canonical alias for this room
|
||||
self._set_canonical_alias(
|
||||
@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotIn("alias", data["content"])
|
||||
self.assertNotIn("alt_aliases", data["content"])
|
||||
|
||||
def test_remove_other_alias(self):
|
||||
def test_remove_other_alias(self) -> None:
|
||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||
# Create a second alias.
|
||||
other_test_alias = "#test2:test"
|
||||
@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [directory.register_servlets, room.register_servlets]
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
|
||||
# Add custom alias creation rules to the config.
|
||||
@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def test_denied(self):
|
||||
def test_denied(self) -> None:
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
|
||||
def test_allowed(self):
|
||||
def test_allowed(self) -> None:
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def test_denied_during_creation(self):
|
||||
def test_denied_during_creation(self) -> None:
|
||||
"""A room alias that is not allowed should be rejected during creation."""
|
||||
# Invalid room alias.
|
||||
self.helper.create_room_as(
|
||||
@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||
extra_content={"room_alias_name": "foo"},
|
||||
)
|
||||
|
||||
def test_allowed_during_creation(self):
|
||||
def test_allowed_during_creation(self) -> None:
|
||||
"""A valid room alias should be allowed during creation."""
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user_id,
|
||||
@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
data = {"room_alias_name": "unofficial_test"}
|
||||
allowed_localpart = "allowed"
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
|
||||
# Add custom room list publication rules to the config.
|
||||
@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||
|
||||
@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_denied_without_publication_permission(self):
|
||||
def test_denied_without_publication_permission(self) -> None:
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user without permission to publish rooms.
|
||||
@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=403,
|
||||
)
|
||||
|
||||
def test_allowed_when_creating_private_room(self):
|
||||
def test_allowed_when_creating_private_room(self) -> None:
|
||||
"""
|
||||
Try to create a room, register an alias for it, and NOT publish it,
|
||||
as a user without permission to publish rooms.
|
||||
@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
def test_allowed_with_publication_permission(self):
|
||||
def test_allowed_with_publication_permission(self) -> None:
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user WITH permission to publish rooms.
|
||||
@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
def test_denied_publication_with_invalid_alias(self):
|
||||
def test_denied_publication_with_invalid_alias(self) -> None:
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
as a user WITH permission to publish rooms.
|
||||
@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
expect_code=403,
|
||||
)
|
||||
|
||||
def test_can_create_as_private_room_after_rejection(self):
|
||||
def test_can_create_as_private_room_after_rejection(self) -> None:
|
||||
"""
|
||||
After failing to publish a room with an alias as a user without publish permission,
|
||||
retry as the same user, but without publishing the room.
|
||||
@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||
self.test_denied_without_publication_permission()
|
||||
self.test_allowed_when_creating_private_room()
|
||||
|
||||
def test_can_create_with_permission_after_rejection(self):
|
||||
def test_can_create_with_permission_after_rejection(self) -> None:
|
||||
"""
|
||||
After failing to publish a room with an alias as a user without publish permission,
|
||||
retry as someone with permission, using the same alias.
|
||||
@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [directory.register_servlets, room.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def test_disabling_room_list(self):
|
||||
def test_disabling_room_list(self) -> None:
|
||||
self.room_list_handler.enable_room_list_search = True
|
||||
self.directory_handler.enable_room_list_search = True
|
||||
|
||||
|
@ -20,33 +20,37 @@ from parameterized import parameterized
|
||||
from signedjson import key as key, sign as sign
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
|
||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
return self.setup_test_homeserver(federation_client=mock.Mock())
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_e2e_keys_handler()
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_query_local_devices_no_devices(self):
|
||||
def test_query_local_devices_no_devices(self) -> None:
|
||||
"""If the user has no devices, we expect an empty list."""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
def test_reupload_one_time_keys(self):
|
||||
def test_reupload_one_time_keys(self) -> None:
|
||||
"""we should be able to re-upload the same keys"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {
|
||||
keys: JsonDict = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
||||
)
|
||||
|
||||
def test_change_one_time_keys(self):
|
||||
def test_change_one_time_keys(self) -> None:
|
||||
"""attempts to change one-time-keys should be rejected"""
|
||||
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_claim_one_time_key(self):
|
||||
def test_claim_one_time_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {"alg1:k1": "key1"}
|
||||
@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_fallback_key(self):
|
||||
def test_fallback_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||
)
|
||||
|
||||
def test_replace_master_key(self):
|
||||
def test_replace_master_key(self) -> None:
|
||||
"""uploading a new signing key should make the old signing key unavailable"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||
|
||||
def test_reupload_signatures(self):
|
||||
def test_reupload_signatures(self) -> None:
|
||||
"""re-uploading a signature should not fail"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
||||
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
||||
|
||||
def test_self_signing_key_doesnt_show_up_as_device(self):
|
||||
def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
|
||||
"""signing keys should be hidden when fetching a user's devices"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
keys1 = {
|
||||
@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
def test_upload_signatures(self):
|
||||
def test_upload_signatures(self) -> None:
|
||||
"""should check signatures that are uploaded"""
|
||||
# set up a user with cross-signing keys and a device. This user will
|
||||
# try uploading signatures
|
||||
@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
||||
)
|
||||
|
||||
def test_query_devices_remote_no_sync(self):
|
||||
def test_query_devices_remote_no_sync(self) -> None:
|
||||
"""Tests that querying keys for a remote user that we don't share a room
|
||||
with returns the cross signing keys correctly.
|
||||
"""
|
||||
@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_query_devices_remote_sync(self):
|
||||
def test_query_devices_remote_sync(self) -> None:
|
||||
"""Tests that querying keys for a remote user that we share a room with,
|
||||
but haven't yet fetched the keys for, returns the cross signing keys
|
||||
correctly.
|
||||
@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
(["device_1", "device_2"],),
|
||||
]
|
||||
)
|
||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
|
||||
"""Test that requests for all of a remote user's devices are cached.
|
||||
|
||||
We do this by asserting that only one call over federation was made, and that
|
||||
@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
local_user_id = "@test:test"
|
||||
remote_user_id = "@test:other"
|
||||
request_body = {"device_keys": {remote_user_id: []}}
|
||||
request_body: JsonDict = {"device_keys": {remote_user_id: []}}
|
||||
|
||||
response_devices = [
|
||||
{
|
||||
|
@ -13,14 +13,18 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.util.macaroons import get_value_from_macaroon
|
||||
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||
@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
|
||||
}
|
||||
|
||||
|
||||
async def get_json(url):
|
||||
async def get_json(url: str) -> JsonDict:
|
||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||
if url == WELL_KNOWN:
|
||||
# Minimal discovery document, as defined in OpenID.Discovery
|
||||
@ -116,6 +120,8 @@ async def get_json(url):
|
||||
elif url == JWKS_URI:
|
||||
return {"keys": []}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _key_file_path() -> str:
|
||||
"""path to a file containing the private half of a test key"""
|
||||
@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
if not HAS_OIDC:
|
||||
skip = "requires OIDC"
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock(spec=["get_json"])
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
self.http_client.user_agent = b"Synapse Test"
|
||||
@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
sso_handler = hs.get_sso_handler()
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
sso_handler.render_error = self.render_error
|
||||
sso_handler.render_error = self.render_error # type: ignore[assignment]
|
||||
|
||||
# Reduce the number of attempts when generating MXIDs.
|
||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||
@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
return args
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_config(self):
|
||||
def test_config(self) -> None:
|
||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||
def test_discovery(self):
|
||||
def test_discovery(self) -> None:
|
||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||
# This would throw if some metadata were invalid
|
||||
metadata = self.get_success(self.provider.load_metadata())
|
||||
@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||
def test_no_discovery(self):
|
||||
def test_no_discovery(self) -> None:
|
||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||
self.get_success(self.provider.load_metadata())
|
||||
self.http_client.get_json.assert_not_called()
|
||||
|
||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||
def test_load_jwks(self):
|
||||
def test_load_jwks(self) -> None:
|
||||
"""JWKS loading is done once (then cached) if used."""
|
||||
jwks = self.get_success(self.provider.load_jwks())
|
||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||
@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_validate_config(self):
|
||||
def test_validate_config(self) -> None:
|
||||
"""Provider metadatas are extensively validated."""
|
||||
h = self.provider
|
||||
|
||||
@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
force_load_metadata()
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||
def test_skip_verification(self):
|
||||
def test_skip_verification(self) -> None:
|
||||
"""Provider metadata validation can be disabled by config."""
|
||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||
# This should not throw
|
||||
get_awaitable_result(self.provider.load_metadata())
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_redirect_request(self):
|
||||
def test_redirect_request(self) -> None:
|
||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||
req = Mock(spec=["cookies"])
|
||||
req.cookies = []
|
||||
@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(redirect, "http://client/redirect")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback_error(self):
|
||||
def test_callback_error(self) -> None:
|
||||
"""Errors from the provider returned in the callback are displayed."""
|
||||
request = Mock(args={})
|
||||
request.args[b"error"] = [b"invalid_client"]
|
||||
@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertRenderedError("invalid_client", "some description")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback(self):
|
||||
def test_callback(self) -> None:
|
||||
"""Code callback works and display errors if something went wrong.
|
||||
|
||||
A lot of scenarios are tested here:
|
||||
@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": username,
|
||||
}
|
||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertRenderedError("mapping_error")
|
||||
|
||||
# Handle ID token errors
|
||||
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
||||
self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_token")
|
||||
|
||||
@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"type": "bearer",
|
||||
"access_token": "access_token",
|
||||
}
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
id_token = {
|
||||
"sid": "abcdefgh",
|
||||
}
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
self.provider._fetch_userinfo.reset_mock()
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle userinfo fetching error
|
||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("fetch_error")
|
||||
|
||||
# Handle code exchange failure
|
||||
from synapse.handlers.oidc import OidcError
|
||||
|
||||
self.provider._exchange_code = simple_async_mock(
|
||||
self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
|
||||
raises=OidcError("invalid_request")
|
||||
)
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
self.assertRenderedError("invalid_request")
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_callback_session(self):
|
||||
def test_callback_session(self) -> None:
|
||||
"""The callback verifies the session presence and validity"""
|
||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||
|
||||
@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
@override_config(
|
||||
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||
)
|
||||
def test_exchange_code(self):
|
||||
def test_exchange_code(self) -> None:
|
||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||
token = {"type": "bearer"}
|
||||
token_json = json.dumps(token).encode("utf-8")
|
||||
@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_exchange_code_jwt_key(self):
|
||||
def test_exchange_code_jwt_key(self) -> None:
|
||||
"""Test that code exchange works with a JWK client secret."""
|
||||
from authlib.jose import jwt
|
||||
|
||||
@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_exchange_code_no_auth(self):
|
||||
def test_exchange_code_no_auth(self) -> None:
|
||||
"""Test that code exchange works with no client secret."""
|
||||
token = {"type": "bearer"}
|
||||
self.http_client.request = simple_async_mock(
|
||||
@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_extra_attributes(self):
|
||||
def test_extra_attributes(self) -> None:
|
||||
"""
|
||||
Login while using a mapping provider that implements get_extra_attributes.
|
||||
"""
|
||||
@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
"username": "foo",
|
||||
"phone": "1234567",
|
||||
}
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_map_userinfo_to_user(self):
|
||||
def test_map_userinfo_to_user(self) -> None:
|
||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
|
||||
userinfo = {
|
||||
userinfo: dict = {
|
||||
"sub": "test_user",
|
||||
"username": "test_user",
|
||||
}
|
||||
@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||
def test_map_userinfo_to_existing_user(self):
|
||||
def test_map_userinfo_to_existing_user(self) -> None:
|
||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||
store = self.hs.get_datastores().main
|
||||
user = UserID.from_string("@test_user:test")
|
||||
@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_map_userinfo_to_invalid_localpart(self):
|
||||
def test_map_userinfo_to_invalid_localpart(self) -> None:
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
self.get_success(
|
||||
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
||||
@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_map_userinfo_to_user_retries(self):
|
||||
def test_map_userinfo_to_user_retries(self) -> None:
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||
def test_empty_localpart(self):
|
||||
def test_empty_localpart(self) -> None:
|
||||
"""Attempts to map onto an empty localpart should be rejected."""
|
||||
userinfo = {
|
||||
"sub": "tester",
|
||||
@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_null_localpart(self):
|
||||
def test_null_localpart(self) -> None:
|
||||
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
||||
userinfo = {
|
||||
"sub": "tester",
|
||||
@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements(self):
|
||||
def test_attribute_requirements(self) -> None:
|
||||
"""The required attributes must be met from the OIDC userinfo response."""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements_contains(self):
|
||||
def test_attribute_requirements_contains(self) -> None:
|
||||
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements_mismatch(self):
|
||||
def test_attribute_requirements_mismatch(self) -> None:
|
||||
"""
|
||||
Test that auth fails if attributes exist but don't match,
|
||||
or are non-string values.
|
||||
@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler.complete_sso_login = simple_async_mock()
|
||||
# userinfo with "test": "not_foobar" attribute should fail
|
||||
userinfo = {
|
||||
userinfo: dict = {
|
||||
"sub": "tester",
|
||||
"username": "tester",
|
||||
"test": "not_foobar",
|
||||
@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
|
||||
|
||||
handler = hs.get_oidc_handler()
|
||||
provider = handler._providers["oidc"]
|
||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
|
||||
provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||
|
||||
state = "state"
|
||||
session = handler._token_generator.generate_oidc_session_token(
|
||||
|
@ -11,14 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.rest import admin
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [admin.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.mock_federation = Mock()
|
||||
self.mock_registry = Mock()
|
||||
|
||||
self.query_handlers = {}
|
||||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||
|
||||
def register_query_handler(query_type, handler):
|
||||
def register_query_handler(
|
||||
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
) -> None:
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
self.mock_registry.register_query_handler = register_query_handler
|
||||
@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
self.frank = UserID.from_string("@1234abcd:test")
|
||||
@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.handler = hs.get_profile_handler()
|
||||
|
||||
def test_get_my_name(self):
|
||||
def test_get_my_name(self) -> None:
|
||||
self.get_success(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
)
|
||||
@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual("Frank", displayname)
|
||||
|
||||
def test_set_my_name(self):
|
||||
def test_set_my_name(self) -> None:
|
||||
self.get_success(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self):
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
self.hs.config.registration.enable_set_displayname = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_set_my_name_noauth(self):
|
||||
def test_set_my_name_noauth(self) -> None:
|
||||
self.get_failure(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||
@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_get_other_name(self):
|
||||
def test_get_other_name(self) -> None:
|
||||
self.mock_federation.make_query.return_value = make_awaitable(
|
||||
{"displayname": "Alice"}
|
||||
)
|
||||
@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self):
|
||||
def test_incoming_fed_query(self) -> None:
|
||||
self.get_success(self.store.create_profile("caroline"))
|
||||
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
||||
|
||||
@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual({"displayname": "Caroline"}, response)
|
||||
|
||||
def test_get_my_avatar(self):
|
||||
def test_get_my_avatar(self) -> None:
|
||||
self.get_success(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual("http://my.server/me.png", avatar_url)
|
||||
|
||||
def test_set_my_avatar(self):
|
||||
def test_set_my_avatar(self) -> None:
|
||||
self.get_success(
|
||||
self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
)
|
||||
|
||||
def test_set_my_avatar_if_disabled(self):
|
||||
def test_set_my_avatar_if_disabled(self) -> None:
|
||||
self.hs.config.registration.enable_set_avatar_url = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
def test_avatar_constraints_no_config(self):
|
||||
def test_avatar_constraints_no_config(self) -> None:
|
||||
"""Tests that the method to check an avatar against configured constraints skips
|
||||
all of its check if no constraint is configured.
|
||||
"""
|
||||
@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertTrue(res)
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_constraints_missing(self):
|
||||
def test_avatar_constraints_missing(self) -> None:
|
||||
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
||||
be found.
|
||||
"""
|
||||
@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertFalse(res)
|
||||
|
||||
@unittest.override_config({"max_avatar_size": 50})
|
||||
def test_avatar_constraints_file_size(self):
|
||||
def test_avatar_constraints_file_size(self) -> None:
|
||||
"""Tests that a file that's above the allowed file size is forbidden but one
|
||||
that's below it is allowed.
|
||||
"""
|
||||
@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.assertFalse(res)
|
||||
|
||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||
def test_avatar_constraint_mime_type(self):
|
||||
def test_avatar_constraint_mime_type(self) -> None:
|
||||
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
||||
an authorised content type is allowed.
|
||||
"""
|
||||
|
@ -12,12 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.test_utils import simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
|
||||
|
||||
|
||||
class SamlHandlerTestCase(HomeserverTestCase):
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
saml_config = {
|
||||
saml_config: Dict[str, Any] = {
|
||||
"sp_config": {"metadata": {}},
|
||||
# Disable grandfathering.
|
||||
"grandfathered_mxid_source_attribute": None,
|
||||
@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver()
|
||||
|
||||
self.handler = hs.get_saml_handler()
|
||||
@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
elif not has_xmlsec1:
|
||||
skip = "Requires xmlsec1"
|
||||
|
||||
def test_map_saml_response_to_user(self):
|
||||
def test_map_saml_response_to_user(self) -> None:
|
||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||
|
||||
# stub out the auth handler
|
||||
@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||
def test_map_saml_response_to_existing_user(self):
|
||||
def test_map_saml_response_to_existing_user(self) -> None:
|
||||
"""Existing users can log in with SAML account."""
|
||||
store = self.hs.get_datastores().main
|
||||
self.get_success(
|
||||
@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
auth_provider_session_id=None,
|
||||
)
|
||||
|
||||
def test_map_saml_response_to_invalid_localpart(self):
|
||||
def test_map_saml_response_to_invalid_localpart(self) -> None:
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
|
||||
# stub out the auth handler
|
||||
@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_not_called()
|
||||
|
||||
def test_map_saml_response_to_user_retries(self):
|
||||
def test_map_saml_response_to_user_retries(self) -> None:
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
|
||||
# stub out the auth handler and error renderer
|
||||
@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_map_saml_response_redirect(self):
|
||||
def test_map_saml_response_redirect(self) -> None:
|
||||
"""Test a mapping provider that raises a RedirectException"""
|
||||
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||
@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_attribute_requirements(self):
|
||||
def test_attribute_requirements(self) -> None:
|
||||
"""The required attributes must be met from the SAML response."""
|
||||
|
||||
# stub out the auth handler
|
||||
|
Loading…
Reference in New Issue
Block a user