Use HTTPStatus constants in place of literals in tests. (#13297)

This commit is contained in:
Dirk Klimpel 2022-07-15 21:31:27 +02:00 committed by GitHub
parent 7b67e93d49
commit 96cf81e312
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 308 additions and 238 deletions

1
changelog.d/13297.misc Normal file
View File

@ -0,0 +1 @@
Use `HTTPStatus` constants in place of literals in tests.

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus
from parameterized import parameterized
@ -58,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@ -119,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -153,7 +154,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def test_send_join(self):
@ -171,7 +172,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# we should get complete room state back
returned_state = [
@ -226,7 +227,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# expect a reduced room state
returned_state = [

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
@ -255,7 +256,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier,
),
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
@ -293,7 +294,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]

View File

@ -14,6 +14,7 @@
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock
@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:

View File

@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])

View File

@ -15,6 +15,7 @@ import json
import os
import re
from email.parser import Parser
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@ -98,7 +99,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@ -347,7 +348,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@ -362,7 +363,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@ -390,7 +391,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
expected_code: int = 200,
expected_code: int = HTTPStatus.OK,
) -> None:
channel = self.make_request(
"POST",
@ -715,7 +716,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@ -725,7 +728,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@ -747,7 +750,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@ -756,7 +759,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@ -781,7 +784,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@ -791,7 +796,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@ -817,7 +822,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@ -827,7 +834,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@ -852,7 +859,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@ -862,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@ -872,7 +881,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
expect_code=200,
expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@ -884,7 +893,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
expect_code=200,
expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@ -895,7 +904,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
expect_code=400,
expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@ -907,28 +916,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
expect_code=200,
expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
expect_code=200,
expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
expect_code=200,
expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
expect_code=400,
expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": []})
@ -940,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
expect_code=400,
expect_code=HTTPStatus.BAD_REQUEST,
)
def _request_token(
@ -948,7 +957,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
expect_code: int = 200,
expect_code: int = HTTPStatus.OK,
) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
@ -993,7 +1002,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
@ -1002,7 +1013,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@ -1052,7 +1063,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@ -1061,7 +1072,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@ -1092,7 +1103,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
expected_status_code=400,
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM,
)
@ -1100,7 +1111,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
expected_status_code=400,
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM,
)
@ -1286,7 +1297,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
expected_status_code: int = 200,
expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,

View File

@ -14,6 +14,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@ -261,20 +262,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@ -288,7 +289,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@ -296,7 +297,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
@ -307,7 +308,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
@ -330,7 +331,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@ -341,14 +342,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@ -367,14 +368,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@ -466,7 +467,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
expected_flow_types = [
"m.login.cas",
@ -494,14 +495,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8")
@ -530,7 +531,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
@ -555,7 +556,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
@ -579,7 +580,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@ -606,7 +607,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
@ -634,7 +635,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
@ -643,18 +644,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@ -765,7 +766,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@ -1246,7 +1247,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
@ -1290,7 +1291,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))),
],
)
self.assertEqual(chan.code, 302, chan.result)
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@ -1300,7 +1301,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
self.assertEqual(chan.code, 302, chan.result)
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@ -1325,5 +1326,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")

File diff suppressed because it is too large Load Diff