Proper types for tests.module_api (#15031)

* -> None for test methods

* A first batch of type fixes

* Introduce common parent test case

* Fixup that big test method

* tests.module_api passes mypy

* Changelog
This commit is contained in:
David Robertson 2023-02-09 00:23:35 +00:00 committed by GitHub
parent 30509a1010
commit 7081bb56e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 54 deletions

1
changelog.d/15031.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,7 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
|tests/module_api/test_api.py
|tests/server.py
)$

View File

@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, override_config
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
override_config,
)
@attr.s
@ -470,7 +474,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update(
testcase: FederatingHomeserverTestCase,
testcase: HomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@ -491,7 +495,7 @@ def send_presence_update(
def sync_presence(
testcase: FederatingHomeserverTestCase,
testcase: HomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:

View File

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester
from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
class ModuleApiTestCase(HomeserverTestCase):
class BaseModuleApiTestCase(HomeserverTestCase):
"""Common properties of the two test case classes."""
module_api: ModuleApi
# These are all written by _test_sending_local_online_presence_to_local_user.
presence_receiver_id: str
presence_receiver_tok: str
presence_sender_id: str
presence_sender_tok: str
class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@ -42,14 +59,14 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastores().main
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
self.sync_handler = homeserver.get_sync_handler()
self.auth_handler = homeserver.get_auth_handler()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.module_api = hs.get_module_api()
self.event_creation_handler = hs.get_event_creation_handler()
self.sync_handler = hs.get_sync_handler()
self.auth_handler = hs.get_auth_handler()
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
@ -58,7 +75,7 @@ class ModuleApiTestCase(HomeserverTestCase):
federation_transport_client=fed_transport_client,
)
def test_can_register_user(self):
def test_can_register_user(self) -> None:
"""Tests that an external module can register a user"""
# Register a new user
user_id, access_token = self.get_success(
@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self):
def test_can_register_admin_user(self) -> None:
user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
def test_can_set_admin(self):
def test_can_set_admin(self) -> None:
user_id = self.register_user(
"alice_wants_admin",
"1234",
@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
def test_can_set_displayname(self):
def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname"
user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False
)
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_userinfo is not None
self.get_success(
self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False
@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob")
def test_get_userinfo_by_id(self):
def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
def test_get_userinfo_by_id__no_user_found(self):
def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
def test_get_user_ip_and_agents(self):
def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user.
@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent.
info = self.get_success(
self.module_api.get_user_ip_and_agents(
user_id, (last_seen_1 + last_seen_2) / 2
user_id, (last_seen_1 + last_seen_2) // 2
)
)
self.assertEqual(len(info), 1)
@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
def test_get_user_ip_and_agents__no_user_found(self):
def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success(
self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com"
@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
def test_sending_events_into_room(self):
def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
self.event_creation_handler.create_and_send_nonmember_event = Mock(
self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event,
)
@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event
content = {"body": "I am a puppet", "msgtype": "m.text"}
content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = {
"room_id": room_id,
"type": "m.room.message",
@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
event: EventBase = self.get_success(
event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
)
self.assertEqual(event.sender, user_id)
@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception
)
def test_public_rooms(self):
def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried.
"""
@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
def test_send_local_online_presence_to(self):
def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
def test_send_local_online_presence_to_federation(self):
def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@ -431,7 +451,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
def test_update_membership(self):
def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme")
@ -554,7 +574,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
def test_update_room_membership_remote_join(self):
def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
@ -582,7 +602,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
def test_get_room_state(self):
def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
@ -677,7 +697,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions({"foo": "bar"})
self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"])
@ -756,7 +776,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias)
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
servlets = [
@ -766,7 +786,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets,
]
def default_config(self):
def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
@ -774,18 +794,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
def prepare(self, reactor, clock, homeserver):
self.module_api = homeserver.get_module_api()
self.sync_handler = homeserver.get_sync_handler()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.module_api = hs.get_module_api()
self.sync_handler = hs.get_sync_handler()
def test_send_local_online_presence_to_workers(self):
def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user(
test_case: HomeserverTestCase, test_with_workers: bool = False
):
test_case: BaseModuleApiTestCase, test_with_workers: bool = False
) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases:
@ -852,6 +872,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to.
assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate()
# Syncing again should result in no presence updates
@ -868,6 +889,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
if test_with_workers:
assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api()
else:
module_api_to_use = test_case.module_api
@ -875,12 +897,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it.
# Note that we're syncing via the master.
d = module_api_to_use.send_local_online_presence_to(
[
test_case.presence_receiver_id,
]
d = defer.ensureDeferred(
module_api_to_use.send_local_online_presence_to(
[test_case.presence_receiver_id],
)
)
d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
@ -897,7 +918,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
presence_update: UserPresenceState = presence_updates[0]
presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@ -908,7 +929,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
presence_update: UserPresenceState = presence_updates[0]
presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@ -936,12 +957,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence.
d = module_api_to_use.send_local_online_presence_to(
d = defer.ensureDeferred(
module_api_to_use.send_local_online_presence_to(
[
test_case.presence_receiver_id,
]
)
d = defer.ensureDeferred(d)
)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the