mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-03 02:44:53 -04:00
Replace context.current_state with context.current_state_ids
This commit is contained in:
parent
17f4f14df7
commit
a3dc1e9cbe
15 changed files with 435 additions and 270 deletions
|
@ -69,7 +69,7 @@ class StateGroupStore(object):
|
|||
|
||||
self._next_group = 1
|
||||
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
groups = {}
|
||||
for event_id in event_ids:
|
||||
group = self._event_to_state_group.get(event_id)
|
||||
|
@ -79,20 +79,20 @@ class StateGroupStore(object):
|
|||
return defer.succeed(groups)
|
||||
|
||||
def store_state_groups(self, event, context):
|
||||
if context.current_state is None:
|
||||
if context.current_state_ids is None:
|
||||
return
|
||||
|
||||
state_events = context.current_state
|
||||
state_events = dict(context.current_state_ids)
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
state_events[(event.type, event.state_key)] = event.event_id
|
||||
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
state_group = self._next_group
|
||||
self._next_group += 1
|
||||
|
||||
self._group_to_state[state_group] = state_events.values()
|
||||
self._group_to_state[state_group] = state_events
|
||||
|
||||
self._event_to_state_group[event.event_id] = state_group
|
||||
|
||||
|
@ -136,7 +136,7 @@ class StateTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.store = Mock(
|
||||
spec_set=[
|
||||
"get_state_groups",
|
||||
"get_state_groups_ids",
|
||||
"add_event_hashes",
|
||||
]
|
||||
)
|
||||
|
@ -187,7 +187,7 @@ class StateTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -196,7 +196,7 @@ class StateTestCase(unittest.TestCase):
|
|||
store.store_state_groups(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
self.assertEqual(2, len(context_store["D"].current_state))
|
||||
self.assertEqual(2, len(context_store["D"].current_state_ids))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_branch_basic_conflict(self):
|
||||
|
@ -239,7 +239,7 @@ class StateTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -303,7 +303,7 @@ class StateTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -384,7 +384,7 @@ class StateTestCase(unittest.TestCase):
|
|||
graph = Graph(nodes, edges)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
|
||||
context_store = {}
|
||||
|
||||
|
@ -424,13 +424,8 @@ class StateTestCase(unittest.TestCase):
|
|||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state), set(context.current_state.values())
|
||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
@ -449,14 +444,8 @@ class StateTestCase(unittest.TestCase):
|
|||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state),
|
||||
set(context.current_state.values())
|
||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
@ -473,20 +462,15 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
group_name = "group_name_1"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name: old_state,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||
}
|
||||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertEqual(group_name, context.state_group)
|
||||
|
@ -503,20 +487,15 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
group_name = "group_name_1"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name: old_state,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||
}
|
||||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
@ -545,7 +524,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state), 6)
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
|
@ -573,7 +552,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state), 6)
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
|
@ -608,7 +587,7 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
|
||||
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
|
||||
|
||||
# Reverse the depth to make sure we are actually using the depths
|
||||
# during state resolution.
|
||||
|
@ -627,15 +606,15 @@ class StateTestCase(unittest.TestCase):
|
|||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
|
||||
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
|
||||
|
||||
def _get_context(self, event, old_state_1, old_state_2):
|
||||
group_name_1 = "group_name_1"
|
||||
group_name_2 = "group_name_2"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name_1: old_state_1,
|
||||
group_name_2: old_state_2,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
|
||||
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
|
||||
}
|
||||
|
||||
return self.state.compute_event_context(event)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue