Require types in tests.storage. (#14646)

Adds missing type hints to `tests.storage` package
and does not allow untyped definitions.
This commit is contained in:
Patrick Cloke 2022-12-09 12:36:32 -05:00 committed by GitHub
parent 94bc21e69f
commit 3ac412b4e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 489 additions and 341 deletions

View file

@ -16,10 +16,15 @@ import logging
from frozendict import frozendict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.server import HomeServer
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase
@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)
def inject_state_event(self, room, sender, typ, state_key, content):
def inject_state_event(
self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
return event
def assertStateMapEqual(self, s1, s2):
def assertStateMapEqual(
self, s1: StateMap[EventBase], s2: StateMap[EventBase]
) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
def test_get_state_groups_ids(self):
def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
self.storage.state.get_state_groups_ids(
self.room.to_string(), [e2.event_id]
)
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
def test_get_state_groups(self):
def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self):
def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
):
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(self):
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self):
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
def test_state_filter_difference_include_other_minus_include_other(self):
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self):
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
def test_state_filter_difference_simple_cases(self):
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
class StateFilterTestCase(TestCase):
def test_return_expanded(self):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).