Add type hints to tests/rest/client (#12108)

* Add type hints to `tests/rest/client`

* newsfile

* fix imports

* add `test_account.py`

* Remove one type hint in `test_report_event.py`

* change `on_create_room` to `async`

* update new functions in `test_third_party_rules.py`

* Add `test_filter.py`

* add `test_rooms.py`

* change to `assertEquals` to `assertEqual`

* lint
This commit is contained in:
Dirk Klimpel 2022-03-02 17:34:14 +01:00 committed by GitHub
parent b4461e7d8a
commit 2ffaf30803
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 421 additions and 350 deletions

View file

@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap
from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
@ -34,7 +40,7 @@ thread_local = threading.local()
class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return True
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> Union[bool, dict]:
return True
@staticmethod
def parse_config(config):
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
def on_create_room(
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> JsonDict:
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
account.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
load_legacy_third_party_event_rules(hs)
@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu):
async def approve_all_signature_checking(
_: RoomVersion, pdu: EventBase
) -> EventBase:
return pdu
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(origin, event, context, *args, **kwargs):
async def _check_event_auth(
origin: str,
event: EventBase,
context: EventContext,
*args: Any,
**kwargs: Any,
) -> EventContext:
return context
hs.get_federation_event_handler()._check_event_auth = _check_event_auth
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return hs
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
except Exception:
pass
def test_third_party_rules(self):
def test_third_party_rules(self) -> None:
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.result["code"], b"403", channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self):
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
is functional: that SynapseErrors are passed through from check_event_allowed
@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
def error_dict(self):
def error_dict(self) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return result
# add a callback that will raise our hacky exception
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
)
def test_cannot_modify_event(self):
def test_cannot_modify_event(self) -> None:
"""cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
ev.content = {"x": "y"}
return True, None
@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# 500 Internal Server Error
self.assertEqual(channel.code, 500, channel.result)
def test_modify_event(self):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {"x": "y"}
return True, d
@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
def test_message_edit(self):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {
"msgtype": "m.text",
@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
def test_send_event(self) -> None:
"""Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
def test_legacy_check_event_allowed(self):
def test_legacy_check_event_allowed(self) -> None:
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
def test_legacy_on_create_room(self):
def test_legacy_on_create_room(self) -> None:
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
def test_sent_event_end_up_in_room_state(self):
def test_sent_event_end_up_in_room_state(self) -> None:
"""Tests that a state event sent by a module while processing another state event
doesn't get dropped from the state of the room. This is to guard against a bug
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
api = self.hs.get_module_api()
# Define a callback that sends a custom event on power levels update.
async def test_fn(event: EventBase, state_events):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
if event.is_state and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i)
def test_on_new_event(self):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
def _update_power_levels(self, event_default: int = 0):
def _update_power_levels(self, event_default: int = 0) -> None:
"""Updates the room's power levels.
Args:
@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
tok=self.tok,
)
def test_on_profile_update(self):
def test_on_profile_update(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates.
"""
@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_profile_update_admin(self):
def test_on_profile_update_admin(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates triggered by a server admin.
"""
@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_user_deactivation_status_changed(self):
def test_on_user_deactivation_status_changed(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation.
"""
@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
args = profile_mock.call_args[0]
self.assertTrue(args[3])
def test_on_user_deactivation_status_changed_admin(self):
def test_on_user_deactivation_status_changed_admin(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation triggered by a server admin as
well as a reactivation.