docstrings and unittests for storage.state (#3958)

I spent ages trying to figure out how I was going mad...
This commit is contained in:
Richard van der Hoff 2018-09-27 11:22:25 +01:00 committed by GitHub
parent 2c695fd1aa
commit ae6ad4cf41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 8 deletions

1
changelog.d/3958.misc Normal file
View File

@ -0,0 +1 @@
Fix docstrings and add tests for state store methods

View File

@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, _room_id, event_ids):
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids: if not event_ids:
defer.returnValue({}) defer.returnValue({})
@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_group(self, state_group): def get_state_ids_for_group(self, state_group):
"""Get the state IDs for the given state group """Get the event IDs of all the state in the given state group
Args: Args:
state_group (int) state_group (int)
@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_state_groups(self, room_id, event_ids): def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events. Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
""" """
if not event_ids: if not event_ids:
defer.returnValue({}) defer.returnValue({})
@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
member events (if True), or to exclude member events (if False) member events (if True), or to exclude member events (if False)
Returns: Returns:
dictionary state_group -> (dict of (type, state_key) -> event id) Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
results = {} results = {}
@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events. If None, `types` filtering is applied to all events.
Returns: Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]] Deferred[dict[int, dict[tuple[str, str], str]]]:
a dictionary mapping from state group to state dictionary. dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if types is not None: if types is not None:
non_member_types = [t for t in types if t[0] != EventTypes.Member] non_member_types = [t for t in types if t[0] != EventTypes.Member]
@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events. If None, `types` filtering is applied to all events.
Returns: Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]] Deferred[dict[int, dict[tuple[str, str], str]]]:
a dictionary mapping from state group to state dictionary. dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if types: if types:
types = frozenset(types) types = frozenset(types)

View File

@ -74,6 +74,45 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2)) self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_groups_ids(self):
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, '', {}
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
)
state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
self.assertDictEqual(
state_map,
{
(EventTypes.Create, ''): e1.event_id,
(EventTypes.Name, ''): e2.event_id,
},
)
@defer.inlineCallbacks
def test_get_state_groups(self):
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, '', {}
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
)
state_group_map = yield self.store.get_state_groups(
self.room, [e2.event_id])
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual(
{ev.event_id for ev in state_list},
{e1.event_id, e2.event_id},
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_state_for_event(self): def test_get_state_for_event(self):