Add type hints to tests/rest/client (#12084)

This commit is contained in:
Dirk Klimpel 2022-02-28 18:47:37 +01:00 committed by GitHub
parent 6c0b44a3d7
commit 1901cb1d4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 198 additions and 138 deletions

1
changelog.d/12084.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View File

@ -84,7 +84,6 @@ exclude = (?x)
|tests/rest/client/test_third_party_rules.py
|tests/rest/client/test_transactions.py
|tests/rest/client/test_typing.py
|tests/rest/client/utils.py
|tests/rest/key/v2/test_remote_key_resource.py
|tests/rest/media/v1/test_base.py
|tests/rest/media/v1/test_media_storage.py
@ -253,7 +252,7 @@ disallow_untyped_defs = True
[mypy-tests.rest.admin.*]
disallow_untyped_defs = True
[mypy-tests.rest.client.test_directory]
[mypy-tests.rest.client.*]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]

View File

@ -13,12 +13,16 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
from typing import Any, Dict
from typing import Any, Dict, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
return self.hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
self.other = self.register_user("other", "pass", displayname="Bob")
def test_get_displayname(self):
def test_get_displayname(self) -> None:
res = self._get_displayname()
self.assertEqual(res, "owner")
def test_set_displayname(self):
def test_set_displayname(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
@ -57,7 +61,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname()
self.assertEqual(res, "test")
def test_set_displayname_noauth(self):
def test_set_displayname_noauth(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
@ -65,7 +69,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_displayname_too_long(self):
def test_set_displayname_too_long(self) -> None:
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
"PUT",
@ -78,11 +82,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname()
self.assertEqual(res, "owner")
def test_get_displayname_other(self):
def test_get_displayname_other(self) -> None:
res = self._get_displayname(self.other)
self.assertEqual(res, "Bob")
def test_set_displayname_other(self):
def test_set_displayname_other(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.other,),
@ -91,11 +95,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
def test_get_avatar_url(self):
def test_get_avatar_url(self) -> None:
res = self._get_avatar_url()
self.assertIsNone(res)
def test_set_avatar_url(self):
def test_set_avatar_url(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
@ -107,7 +111,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_avatar_url()
self.assertEqual(res, "http://my.server/pic.gif")
def test_set_avatar_url_noauth(self):
def test_set_avatar_url_noauth(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.owner,),
@ -115,7 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 401, channel.result)
def test_set_avatar_url_too_long(self):
def test_set_avatar_url_too_long(self) -> None:
"""Attempts to set a stupid avatar_url should get a 400"""
channel = self.make_request(
"PUT",
@ -128,11 +132,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_avatar_url()
self.assertIsNone(res)
def test_get_avatar_url_other(self):
def test_get_avatar_url_other(self) -> None:
res = self._get_avatar_url(self.other)
self.assertIsNone(res)
def test_set_avatar_url_other(self):
def test_set_avatar_url_other(self) -> None:
channel = self.make_request(
"PUT",
"/profile/%s/avatar_url" % (self.other,),
@ -141,14 +145,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
def _get_displayname(self, name=None):
def _get_displayname(self, name: Optional[str] = None) -> str:
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
def _get_avatar_url(self, name=None):
def _get_avatar_url(self, name: Optional[str] = None) -> str:
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
@ -156,7 +160,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_size_limit_global(self):
def test_avatar_size_limit_global(self) -> None:
"""Tests that the maximum size limit for avatars is enforced when updating a
global profile.
"""
@ -187,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_size_limit_per_room(self):
def test_avatar_size_limit_per_room(self) -> None:
"""Tests that the maximum size limit for avatars is enforced when updating a
per-room profile.
"""
@ -220,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_allowed_mime_type_global(self):
def test_avatar_allowed_mime_type_global(self) -> None:
"""Tests that the MIME type whitelist for avatars is enforced when updating a
global profile.
"""
@ -251,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_allowed_mime_type_per_room(self):
def test_avatar_allowed_mime_type_per_room(self) -> None:
"""Tests that the MIME type whitelist for avatars is enforced when updating a
per-room profile.
"""
@ -283,7 +287,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
Args:
@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_auth_for_profile_requests"] = True
config["limit_profile_requests_to_users_who_share_rooms"] = True
@ -325,7 +328,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# User owning the requested profile.
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
@ -337,22 +340,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
def test_no_auth(self):
def test_no_auth(self) -> None:
self.try_fetch_profile(401)
def test_not_in_shared_room(self):
def test_not_in_shared_room(self) -> None:
self.ensure_requester_left_room()
self.try_fetch_profile(403, access_token=self.requester_tok)
def test_in_shared_room(self):
def test_in_shared_room(self) -> None:
self.ensure_requester_left_room()
self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
self.try_fetch_profile(200, self.requester_tok)
def try_fetch_profile(self, expected_code, access_token=None):
def try_fetch_profile(
self, expected_code: int, access_token: Optional[str] = None
) -> None:
self.request_profile(expected_code, access_token=access_token)
self.request_profile(
@ -363,13 +368,18 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
expected_code, url_suffix="/avatar_url", access_token=access_token
)
def request_profile(self, expected_code, url_suffix="", access_token=None):
def request_profile(
self,
expected_code: int,
url_suffix: str = "",
access_token: Optional[str] = None,
) -> None:
channel = self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
self.assertEqual(channel.code, expected_code, channel.result)
def ensure_requester_left_room(self):
def ensure_requester_left_room(self) -> None:
try:
self.helper.leave(
room=self.room_id, user=self.requester, tok=self.requester_tok
@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
profile.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_auth_for_profile_requests"] = True
config["limit_profile_requests_to_users_who_share_rooms"] = True
@ -397,12 +407,12 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# User requesting the profile.
self.requester = self.register_user("requester", "pass")
self.requester_tok = self.login("requester", "pass")
def test_can_lookup_own_profile(self):
def test_can_lookup_own_profile(self) -> None:
"""Tests that a user can lookup their own profile without having to be in a room
if 'require_auth_for_profile_requests' is set to true in the server's config.
"""

View File

@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
]
hijack_auth = False
def test_enabled_on_creation(self):
def test_enabled_on_creation(self) -> None:
"""
Tests the GET and PUT of push rules' `enabled` endpoints.
Tests that a rule is enabled upon creation, even though a rule with that
@ -56,7 +56,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
def test_enabled_on_recreation(self):
def test_enabled_on_recreation(self) -> None:
"""
Tests the GET and PUT of push rules' `enabled` endpoints.
Tests that a rule is enabled upon creation, even if a rule with that
@ -113,7 +113,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
def test_enabled_disable(self):
def test_enabled_disable(self) -> None:
"""
Tests the GET and PUT of push rules' `enabled` endpoints.
Tests that a rule is disabled and enabled when we ask for it.
@ -166,7 +166,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
def test_enabled_404_when_get_non_existent(self):
def test_enabled_404_when_get_non_existent(self) -> None:
"""
Tests that `enabled` gives 404 when the rule doesn't exist.
"""
@ -212,7 +212,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_enabled_404_when_get_non_existent_server_rule(self):
def test_enabled_404_when_get_non_existent_server_rule(self) -> None:
"""
Tests that `enabled` gives 404 when the server-default rule doesn't exist.
"""
@ -226,7 +226,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_enabled_404_when_put_non_existent_rule(self):
def test_enabled_404_when_put_non_existent_rule(self) -> None:
"""
Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
"""
@ -243,7 +243,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_enabled_404_when_put_non_existent_server_rule(self):
def test_enabled_404_when_put_non_existent_server_rule(self) -> None:
"""
Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
"""
@ -260,7 +260,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_actions_get(self):
def test_actions_get(self) -> None:
"""
Tests that `actions` gives you what you expect on a fresh rule.
"""
@ -289,7 +289,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
)
def test_actions_put(self):
def test_actions_put(self) -> None:
"""
Tests that PUT on actions updates the value you'd get from GET.
"""
@ -325,7 +325,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["actions"], ["dont_notify"])
def test_actions_404_when_get_non_existent(self):
def test_actions_404_when_get_non_existent(self) -> None:
"""
Tests that `actions` gives 404 when the rule doesn't exist.
"""
@ -365,7 +365,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_actions_404_when_get_non_existent_server_rule(self):
def test_actions_404_when_get_non_existent_server_rule(self) -> None:
"""
Tests that `actions` gives 404 when the server-default rule doesn't exist.
"""
@ -379,7 +379,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_actions_404_when_put_non_existent_rule(self):
def test_actions_404_when_put_non_existent_rule(self) -> None:
"""
Tests that `actions` gives 404 when putting to a rule that doesn't exist.
"""
@ -396,7 +396,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_actions_404_when_put_non_existent_server_rule(self):
def test_actions_404_when_put_non_existent_server_rule(self) -> None:
"""
Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
"""

View File

@ -11,9 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase):
sync.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["rc_message"] = {"per_second": 0.2, "burst_count": 10}
@ -36,7 +42,7 @@ class RedactionsTestCase(HomeserverTestCase):
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register a couple of users
self.mod_user_id = self.register_user("user1", "pass")
self.mod_access_token = self.login("user1", "pass")
@ -60,7 +66,9 @@ class RedactionsTestCase(HomeserverTestCase):
room=self.room_id, user=self.other_user_id, tok=self.other_access_token
)
def _redact_event(self, access_token, room_id, event_id, expect_code=200):
def _redact_event(
self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
) -> JsonDict:
"""Helper function to send a redaction event.
Returns the json body.
@ -71,13 +79,13 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertEqual(int(channel.result["code"]), expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token, room_id):
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
def test_redact_event_as_moderator(self):
def test_redact_event_as_moderator(self) -> None:
# as a regular user, send a message to redact
b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
msg_id = b["event_id"]
@ -98,7 +106,7 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
self.assertEqual(timeline[-2]["content"], {})
def test_redact_event_as_normal(self):
def test_redact_event_as_normal(self) -> None:
# as a regular user, send a message to redact
b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
normal_msg_id = b["event_id"]
@ -133,7 +141,7 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id)
self.assertEqual(timeline[-3]["content"], {})
def test_redact_nonexistent_event(self):
def test_redact_nonexistent_event(self) -> None:
# control case: an existing event
b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
msg_id = b["event_id"]
@ -158,7 +166,7 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
self.assertEqual(timeline[-2]["content"], {})
def test_redact_create_event(self):
def test_redact_create_event(self) -> None:
# control case: an existing event
b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token)
msg_id = b["event_id"]
@ -178,7 +186,7 @@ class RedactionsTestCase(HomeserverTestCase):
self.other_access_token, self.room_id, create_event_id, expect_code=403
)
def test_redact_event_as_moderator_ratelimit(self):
def test_redact_event_as_moderator_ratelimit(self) -> None:
"""Tests that the correct ratelimiting is applied to redactions"""
message_ids = []

View File

@ -18,11 +18,15 @@ import urllib.parse
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
from synapse.storage.relations import RelationPaginationToken
from synapse.types import JsonDict, StreamToken
from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
@ -52,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id, self.user_token = self._create_user("alice")
@ -63,7 +67,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
self.parent_id = res["event_id"]
def test_send_relation(self):
def test_send_relation(self) -> None:
"""Tests that sending a relation using the new /send_relation works
creates the right shape of event.
"""
@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body,
)
def test_deny_invalid_event(self):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
channel = self._send_relation(
RelationTypes.ANNOTATION,
@ -125,7 +129,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self):
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@ -138,7 +142,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self):
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEqual(200, channel.code, channel.json_body)
@ -146,7 +150,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEqual(400, channel.code, channel.json_body)
def test_deny_forked_thread(self):
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
channel = self._send_relation(
RelationTypes.THREAD,
@ -165,7 +169,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEqual(200, channel.code, channel.json_body)
@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
).to_string(self.store)
)
def test_repeated_paginate_relations(self):
def test_repeated_paginate_relations(self) -> None:
"""Test that if we paginate using a limit and tokens then we get the
expected events.
"""
@ -303,7 +307,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
def test_pagination_from_sync_and_messages(self):
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
self.assertEqual(200, channel.code, channel.json_body)
@ -362,7 +366,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
)
def test_aggregation_pagination_groups(self):
def test_aggregation_pagination_groups(self) -> None:
"""Test that we can paginate annotation groups correctly."""
# We need to create ten separate users to send each reaction.
@ -427,7 +431,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(sent_groups, found_groups)
def test_aggregation_pagination_within_group(self):
def test_aggregation_pagination_within_group(self) -> None:
"""Test that we can paginate within an annotation group."""
# We need to create ten separate users to send each reaction.
@ -524,7 +528,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
def test_aggregation(self):
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@ -556,7 +560,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
},
)
def test_aggregation_redactions(self):
def test_aggregation_redactions(self) -> None:
"""Test that annotations get correctly aggregated after a redaction."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@ -590,7 +594,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
def test_aggregation_must_be_annotation(self):
def test_aggregation_must_be_annotation(self) -> None:
"""Test that aggregations must be annotations."""
channel = self.make_request(
@ -604,7 +608,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
)
def test_bundled_aggregations(self):
def test_bundled_aggregations(self) -> None:
"""
Test that annotations, references, and threads get correctly bundled.
@ -746,7 +750,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
assert_bundle(self._find_event_in_chunk(chunk))
def test_aggregation_get_event_for_annotation(self):
def test_aggregation_get_event_for_annotation(self) -> None:
"""Test that annotations do not get bundled aggregations included
when directly requested.
"""
@ -768,7 +772,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
def test_aggregation_get_event_for_thread(self):
def test_aggregation_get_event_for_thread(self) -> None:
"""Test that threads get bundled aggregations included when directly requested."""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEqual(200, channel.code, channel.json_body)
@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self):
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@ -927,7 +931,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("m.relations", channel.json_body["unsigned"])
@unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
def test_edit(self):
def test_edit(self) -> None:
"""Test that a simple edit works."""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
assert_bundle(self._find_event_in_chunk(chunk))
def test_multi_edit(self):
def test_multi_edit(self) -> None:
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
"""
@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_edit_reply(self):
def test_edit_reply(self) -> None:
"""Test that editing a reply works."""
# Create a reply to edit.
@ -1124,7 +1128,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self):
def test_edit_thread(self) -> None:
"""Test that editing a thread works."""
# Create a thread and edit the last event.
@ -1163,7 +1167,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
latest_event_in_thread = thread_summary["latest_event"]
self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
def test_edit_edit(self):
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"}
channel = self._send_relation(
@ -1213,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_relations_redaction_redacts_edits(self):
def test_relations_redaction_redacts_edits(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
@ -1269,7 +1273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
def test_aggregations_redaction_prevents_access_to_aggregations(self):
def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
"""Test that annotations of an event are redacted when the original event
is redacted.
"""
@ -1309,7 +1313,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
def test_unknown_relations(self):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
self.assertEqual(200, channel.code, channel.json_body)
@ -1417,7 +1421,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return user_id, access_token
def test_background_update(self):
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)

View File

@ -13,9 +13,14 @@
# limitations under the License.
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["retention"] = {
"enabled": True,
@ -47,7 +52,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
return self.hs
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
@ -55,7 +60,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.serializer = self.hs.get_event_client_serializer()
self.clock = self.hs.get_clock()
def test_retention_event_purged_with_state_event(self):
def test_retention_event_purged_with_state_event(self) -> None:
"""Tests that expired events are correctly purged when the room's retention policy
is defined by a state event.
"""
@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
def test_retention_event_purged_with_state_event_outside_allowed(self):
def test_retention_event_purged_with_state_event_outside_allowed(self) -> None:
"""Tests that the server configuration can override the policy for a room when
running the purge jobs.
"""
@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# instead of the one specified in the room's policy.
self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
def test_retention_event_purged_without_state_event(self) -> None:
"""Tests that expired events are correctly purged when the room's retention policy
is defined by the server's configuration's default retention policy.
"""
@ -110,7 +115,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
def test_visibility(self):
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
outdated events
"""
@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
def _test_retention_event_purged(self, room_id: str, increment: float):
def _test_retention_event_purged(self, room_id: str, increment: float) -> None:
"""Run the following test scenario to test the message retention policy support:
1. Send event 1
@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
expired_event_id = resp.get("event_id")
assert expired_event_id is not None
# Check that we can retrieve the event.
expired_event = self.get_event(expired_event_id)
@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
assert valid_event_id is not None
# Advance the time again. Now our first event should have expired but our second
# one should still be kept.
@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# has been purged.
self.get_event(room_id, create_event.event_id)
def get_event(self, event_id, expect_none=False):
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True))
if expect_none:
@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["retention"] = {
"enabled": True,
@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
)
return self.hs
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
def test_no_default_policy(self):
def test_no_default_policy(self) -> None:
"""Tests that an event doesn't get expired if there is neither a default retention
policy nor a policy specific to the room.
"""
@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
self._test_retention(room_id)
def test_state_policy(self):
def test_state_policy(self) -> None:
"""Tests that an event gets correctly expired if there is no default retention
policy but there's a policy specific to the room.
"""
@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
self._test_retention(room_id, expected_code_for_first_event=404)
def _test_retention(self, room_id, expected_code_for_first_event=200):
def _test_retention(
self, room_id: str, expected_code_for_first_event: int = 200
) -> None:
# Send a first event to the room. This is the event we'll want to be purged at the
# end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
assert first_event_id is not None
# Check that we can retrieve the event.
expired_event = self.get_event(room_id, first_event_id)
@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
second_event_id = resp.get("event_id")
assert second_event_id is not None
# Advance the time by another month.
self.reactor.advance(one_day_ms * 30 / 1000)
@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
second_event = self.get_event(room_id, second_event_id)
self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
def get_event(self, room_id, event_id, expected_code=200):
def get_event(
self, room_id: str, event_id: str, expected_code: int = 200
) -> JsonDict:
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
channel = self.make_request("GET", url, access_token=self.token)

View File

@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
def test_user_to_user(self):
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
@ -73,7 +73,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_local_room_key_request(self):
def test_local_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from local user"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
@ -128,7 +128,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_remote_room_key_request(self):
def test_remote_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from remote user"""
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
@ -199,7 +199,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
},
)
def test_limited_sync(self):
def test_limited_sync(self) -> None:
"""If a limited sync for to-devices happens the next /sync should respond immediately."""
self.register_user("u1", "pass")

View File

@ -14,6 +14,8 @@
from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import (
@ -23,13 +25,15 @@ from synapse.rest.client import (
room,
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
class _ShadowBannedBase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Create two users, one of which is shadow-banned.
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")
@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase):
room_upgrade_rest_servlet.register_servlets,
]
def test_invite(self):
def test_invite(self) -> None:
"""Invites from shadow-banned users don't actually get sent."""
# The create works fine.
@ -77,7 +81,7 @@ class RoomTestCase(_ShadowBannedBase):
)
self.assertEqual(invited_rooms, [])
def test_invite_3pid(self):
def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler()
identity_handler.lookup_3pid = Mock(
@ -101,7 +105,7 @@ class RoomTestCase(_ShadowBannedBase):
# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()
def test_create_room(self):
def test_create_room(self) -> None:
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
channel = self.make_request(
@ -126,7 +130,7 @@ class RoomTestCase(_ShadowBannedBase):
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
def test_message(self):
def test_message(self) -> None:
"""Messages from shadow-banned users don't actually get sent."""
room_id = self.helper.create_room_as(
@ -151,7 +155,7 @@ class RoomTestCase(_ShadowBannedBase):
)
self.assertNotIn(event_id, latest_events)
def test_upgrade(self):
def test_upgrade(self) -> None:
"""A room upgrade should fail, but look like it succeeded."""
# The create works fine.
@ -177,7 +181,7 @@ class RoomTestCase(_ShadowBannedBase):
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})
def test_typing(self):
def test_typing(self) -> None:
"""Typing notifications should not be propagated into the room."""
# The create works fine.
room_id = self.helper.create_room_as(
@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase):
room.register_servlets,
]
def test_displayname(self):
def test_displayname(self) -> None:
"""Profile changes should succeed, but don't end up in a room."""
original_display_name = "banned"
new_display_name = "new name"
@ -281,7 +285,7 @@ class ProfileTestCase(_ShadowBannedBase):
event.content, {"membership": "join", "displayname": original_display_name}
)
def test_room_displayname(self):
def test_room_displayname(self) -> None:
"""Changes to state events for a room should be processed, but not end up in the room."""
original_display_name = "banned"
new_display_name = "new name"

View File

@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, room, shared_rooms
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
shared_rooms.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler()
def _get_shared_rooms(self, token, other_user) -> FakeChannel:
def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
@ -47,14 +51,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
access_token=token,
)
def test_shared_room_list_public(self):
def test_shared_room_list_public(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is public.
"""
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self):
def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
@ -63,7 +67,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
room_one_is_public=False, room_two_is_public=False
)
def test_shared_room_list_mixed(self):
def test_shared_room_list_mixed(self) -> None:
"""
The shared room list between two users should contain both public and private
rooms.
@ -72,7 +76,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
def _check_shared_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
):
) -> None:
"""Checks that shared public or private rooms between two users appear in
their shared room lists
"""
@ -109,7 +113,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
for room_id_id in channel.json_body["joined"]:
self.assertIn(room_id_id, [room_id_one, room_id_two])
def test_shared_room_list_after_leave(self):
def test_shared_room_list_after_leave(self) -> None:
"""
A room should no longer be considered shared if the other
user has left it.

View File

@ -13,11 +13,14 @@
# limitations under the License.
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventContentFields, EventTypes, RoomTypes
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.rest import admin
from synapse.rest.client import login, room, room_upgrade_rest_servlet
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
@ -31,7 +34,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
room_upgrade_rest_servlet.register_servlets,
]
def prepare(self, reactor, clock, hs: "HomeServer"):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.creator = self.register_user("creator", "pass")
@ -60,7 +63,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
access_token=token or self.creator_token,
)
def test_upgrade(self):
def test_upgrade(self) -> None:
"""
Upgrading a room should work fine.
"""
@ -68,7 +71,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.result)
self.assertIn("replacement_room", channel.json_body)
def test_not_in_room(self):
def test_not_in_room(self) -> None:
"""
Upgrading a room should work fine.
"""
@ -79,7 +82,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(roomless_token)
self.assertEqual(403, channel.code, channel.result)
def test_power_levels(self):
def test_power_levels(self) -> None:
"""
Another user can upgrade the room if their power level is increased.
"""
@ -105,7 +108,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(self.other_token)
self.assertEqual(200, channel.code, channel.result)
def test_power_levels_user_default(self):
def test_power_levels_user_default(self) -> None:
"""
Another user can upgrade the room if the default power level for users is increased.
"""
@ -131,7 +134,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(self.other_token)
self.assertEqual(200, channel.code, channel.result)
def test_power_levels_tombstone(self):
def test_power_levels_tombstone(self) -> None:
"""
Another user can upgrade the room if they can send the tombstone event.
"""
@ -164,7 +167,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
)
self.assertNotIn(self.other, power_levels["users"])
def test_space(self):
def test_space(self) -> None:
"""Test upgrading a space."""
# Create a space.

View File

@ -41,6 +41,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
@ -48,15 +49,15 @@ from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@attr.s(auto_attribs=True)
class RestHelper:
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
hs = attr.ib()
site = attr.ib(type=Site)
auth_user_id = attr.ib()
hs: HomeServer
site: Site
auth_user_id: Optional[str]
@overload
def create_room_as(
@ -145,7 +146,7 @@ class RestHelper:
def invite(
self,
room: Optional[str] = None,
room: str,
src: Optional[str] = None,
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
@ -216,7 +217,7 @@ class RestHelper:
def leave(
self,
room: Optional[str] = None,
room: str,
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
@ -230,14 +231,22 @@ class RestHelper:
expect_code=expect_code,
)
def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None:
def ban(
self,
room: str,
src: str,
targ: str,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.BAN,
**kwargs,
expect_code=expect_code,
)
def change_membership(
@ -378,7 +387,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Optional[Dict[str, Any]],
tok: str,
tok: Optional[str],
expect_code: int = HTTPStatus.OK,
state_key: str = "",
method: str = "GET",
@ -458,7 +467,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Dict[str, Any],
tok: str,
tok: Optional[str],
expect_code: int = HTTPStatus.OK,
state_key: str = "",
) -> JsonDict:
@ -658,7 +667,12 @@ class RestHelper:
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
async def mock_req(
method: str,
uri: str,
data: Optional[dict] = None,
headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(