From 54e74cc15f30585f5874780437614c0df6f639d9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 24 Feb 2022 19:56:38 +0100 Subject: [PATCH] Add type hints to `tests/rest/client` (#12072) --- changelog.d/12072.misc | 1 + tests/rest/client/test_consent.py | 19 +++-- tests/rest/client/test_device_lists.py | 16 ++-- tests/rest/client/test_ephemeral_message.py | 19 +++-- tests/rest/client/test_identity.py | 13 +++- tests/rest/client/test_keys.py | 6 +- tests/rest/client/test_password_policy.py | 39 +++++----- tests/rest/client/test_power_levels.py | 47 +++++++----- tests/rest/client/test_presence.py | 15 ++-- tests/rest/client/test_room_batch.py | 2 +- tests/rest/client/utils.py | 85 +++++++++++++-------- 11 files changed, 160 insertions(+), 102 deletions(-) create mode 100644 changelog.d/12072.misc diff --git a/changelog.d/12072.misc b/changelog.d/12072.misc new file mode 100644 index 000000000..0360dbd61 --- /dev/null +++ b/changelog.d/12072.misc @@ -0,0 +1 @@ +Add type hints to `tests/rest/client`. diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index fcdc56581..b1ca81a91 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -13,11 +13,16 @@ # limitations under the License. import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder from synapse.rest.client import login, room from synapse.rest.consent import consent_resource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["form_secret"] = "123abc" @@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_render_public_consent(self): + def test_render_public_consent(self) -> None: """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( @@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): "/consent?v=1", shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) - def test_accept_consent(self): + def test_accept_consent(self) -> None: """ A user can use the consent form to accept the terms. """ @@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and whether we've consented version, consented = channel.result["body"].decode("ascii").split(",") @@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Fetch the consent page, to get the consent version -- it should have # changed @@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py index 16070cf02..a8af4e243 100644 --- a/tests/rest/client/test_device_lists.py +++ b/tests/rest/client/test_device_lists.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, login, register @@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_receiving_local_device_list_changes(self): + def test_receiving_local_device_list_changes(self) -> None: """Tests that a local users that share a room receive each other's device list changes. """ @@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync contains the updated device list. # If not, the client would only receive the device list update on the @@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): ) self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - def test_not_receiving_local_device_list_changes(self): + def test_not_receiving_local_device_list_changes(self) -> None: """Tests a local users DO NOT receive device updates from each other if they do not share a room. """ @@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): "/sync", access_token=bob_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) next_batch_token = channel.json_body["next_batch"] # ...and then an incremental sync. This should block until the sync stream is woken up, @@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): }, access_token=alice_access_token, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that bob's incremental sync does not contain the updated device list. bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( "changed", [] diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 3d7aa8ec8..9fa1f82df 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_ephemeral_messages"] = True @@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_message_expiry_no_delay(self): + def test_message_expiry_no_delay(self) -> None: """Tests that sending a message sent with a m.self_destruct_after field set to the past results in that event being deleted right away. """ @@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def test_message_expiry_delay(self): + def test_message_expiry_delay(self) -> None: """Tests that sending a message with a m.self_destruct_after field set to the future results in that event not being deleted right away, but advancing the clock to after that expiry timestamp causes the event to be deleted. @@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index becb4e8dc..299b9d21e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -13,9 +13,14 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_3pid_lookup"] = False @@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs - def test_3pid_lookup_disabled(self): + def test_3pid_lookup_disabled(self) -> None: self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] params = { @@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d7fa635ea..bbc8e7424 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def test_rejects_device_id_ice_key_outside_of_list(self): + def test_rejects_device_id_ice_key_outside_of_list(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_rejects_device_key_given_as_map_to_bool(self): + def test_rejects_device_key_given_as_map_to_bool(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_requires_device_key(self): + def test_requires_device_key(self) -> None: """`device_keys` is required. We should complain if it's missing.""" self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3cf587189..3a74d2e96 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -13,11 +13,16 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import account, login, password_policy, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.register_url = "/_matrix/client/r0/register" self.policy = { "enabled": True, @@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_get_policy(self): + def test_get_policy(self) -> None: """Tests if the /password_policy endpoint returns the configured policy.""" channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual( channel.json_body, { @@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_password_too_short(self): + def test_password_too_short(self) -> None: request_data = json.dumps({"username": "kermit", "password": "shorty"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, ) - def test_password_no_digit(self): + def test_password_no_digit(self) -> None: request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, ) - def test_password_no_symbol(self): + def test_password_no_symbol(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, ) - def test_password_no_uppercase(self): + def test_password_no_uppercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, ) - def test_password_no_lowercase(self): + def test_password_no_lowercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, ) - def test_password_compliant(self): + def test_password_compliant(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) - def test_password_change(self): + def test_password_change(self) -> None: """This doesn't test every possible use case, only that hitting /account/password triggers the password validation code. """ @@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): access_token=tok, ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index c0de4c93a..27dcfc83d 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a room admin, moderator and regular user self.admin_user_id = self.register_user("admin", "pass") self.admin_access_token = self.login("admin", "pass") @@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, ) - def test_non_admins_cannot_enable_room_encryption(self): + def test_non_admins_cannot_enable_room_encryption(self) -> None: # have the mod try to enable room encryption self.helper.send_state( self.room_id, @@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_send_server_acl(self): + def test_non_admins_cannot_send_server_acl(self) -> None: # have the mod try to send a server ACL self.helper.send_state( self.room_id, @@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a server ACL @@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_tombstone_room(self): + def test_non_admins_cannot_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a tombstone event @@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase): expect_code=403, # expect failure ) - def test_admins_can_enable_room_encryption(self): + def test_admins_can_enable_room_encryption(self) -> None: # have the admin try to enable room encryption self.helper.send_state( self.room_id, "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_send_server_acl(self): + def test_admins_can_send_server_acl(self) -> None: # have the admin try to send a server ACL self.helper.send_state( self.room_id, @@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_tombstone_room(self): + def test_admins_can_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_cannot_set_string_power_levels(self): + def test_cannot_set_string_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_large_power_levels(self): + def test_cannot_set_unsafe_large_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_small_power_levels(self): + def test_cannot_set_unsafe_small_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 56fe1a3d0..0abe378fe 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [presence.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) @@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): return hs - def test_put_presence(self): + def test_put_presence(self) -> None: """ PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. @@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): + def test_put_presence_disabled(self) -> None: """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. @@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index e9f870403..44f333a0e 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): return room_id, event_id_a, event_id_b, event_id_c @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self): + def test_same_state_groups_for_whole_historical_batch(self) -> None: """Make sure that when using the `/batch_send` endpoint to import a bunch of historical messages, it re-uses the same `state_group` across the whole batch. This is an easy optimization to make sure we're getting diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 2b3fdadff..46cd5f70a 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,6 +19,7 @@ import json import re import time import urllib.parse +from http import HTTPStatus from typing import ( Any, AnyStr, @@ -89,7 +90,7 @@ class RestHelper: is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> Optional[str]: @@ -137,12 +138,19 @@ class RestHelper: assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - if expect_code == 200: + if expect_code == HTTPStatus.OK: return channel.json_body["room_id"] else: return None - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + def invite( + self, + room: Optional[str] = None, + src: Optional[str] = None, + targ: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=src, @@ -156,7 +164,7 @@ class RestHelper: self, room: str, user: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, ) -> None: @@ -170,7 +178,14 @@ class RestHelper: expect_code=expect_code, ) - def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): + def knock( + self, + room: Optional[str] = None, + user: Optional[str] = None, + reason: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: temp_id = self.auth_user_id self.auth_user_id = user path = "/knock/%s" % room @@ -199,7 +214,13 @@ class RestHelper: self.auth_user_id = temp_id - def leave(self, room=None, user=None, expect_code=200, tok=None): + def leave( + self, + room: Optional[str] = None, + user: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, @@ -209,7 +230,7 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object): + def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, @@ -228,7 +249,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, ) -> None: """ @@ -294,13 +315,13 @@ class RestHelper: def send( self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, + room_id: str, + body: Optional[str] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if body is None: body = "body_text_here" @@ -318,14 +339,14 @@ class RestHelper: def send_event( self, - room_id, - type, + room_id: str, + type: str, content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -358,10 +379,10 @@ class RestHelper: event_type: str, body: Optional[Dict[str, Any]], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", - ) -> Dict: + ) -> JsonDict: """Read or write some state from a given room Args: @@ -410,9 +431,9 @@ class RestHelper: room_id: str, event_type: str, tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Gets some state from a room Args: @@ -438,9 +459,9 @@ class RestHelper: event_type: str, body: Dict[str, Any], tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Set some state in a room Args: @@ -467,8 +488,8 @@ class RestHelper: image_data: bytes, tok: str, filename: str = "test.png", - expect_code: int = 200, - ) -> dict: + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: """Upload a piece of test media to the media repo Args: resource: The resource that will handle the upload request @@ -513,7 +534,7 @@ class RestHelper: channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200, channel.result + assert channel.code == HTTPStatus.OK, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -532,7 +553,7 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == 200 + assert channel.code == HTTPStatus.OK return channel.json_body def auth_via_oidc( @@ -641,7 +662,7 @@ class RestHelper: (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, + code=HTTPStatus.OK, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), ) @@ -739,7 +760,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint ) # that should serve a confirmation page - assert channel.code == 200, channel.text_body + assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) # parse the confirmation page to fish out the link.