Convert state resolution to async/await (#7942)

This commit is contained in:
Patrick Cloke 2020-07-24 10:59:51 -04:00 committed by GitHub
parent e739b20588
commit b975fa2e99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 198 additions and 184 deletions

View file

@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
return state_group
return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
return {
e_id: self._event_id_to_event[e_id]
for e_id in event_ids
if e_id in self._event_id_to_event
}
return defer.succeed(
{
e_id: self._event_id_to_event[e_id]
for e_id in event_ids
if e_id in self._event_id_to_event
}
)
def get_state_group_delta(self, name):
return None, None
return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(event, old_state=old_state)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
context = yield self.state.compute_event_context(event, old_state=old_state)
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
group_name = self.store.store_state_group(
group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
group_name = self.store.store_state_group(
group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
@defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
sg1 = self.store.store_state_group(
sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = self.store.store_state_group(
sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event)
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result