mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Pass room_id to get_auth_chain_difference (#8879)
This is so that we can choose which algorithm to use based on the room ID.
This commit is contained in:
parent
b774c555d8
commit
df4b1e9c74
1
changelog.d/8879.misc
Normal file
1
changelog.d/8879.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Pass `room_id` to `get_auth_chain_difference`.
|
@ -783,7 +783,7 @@ class StateResolutionStore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_auth_chain_difference(
|
def get_auth_chain_difference(
|
||||||
self, state_sets: List[Set[str]]
|
self, room_id: str, state_sets: List[Set[str]]
|
||||||
) -> Awaitable[Set[str]]:
|
) -> Awaitable[Set[str]]:
|
||||||
"""Given sets of state events figure out the auth chain difference (as
|
"""Given sets of state events figure out the auth chain difference (as
|
||||||
per state res v2 algorithm).
|
per state res v2 algorithm).
|
||||||
@ -796,4 +796,4 @@ class StateResolutionStore:
|
|||||||
An awaitable that resolves to a set of event IDs.
|
An awaitable that resolves to a set of event IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.store.get_auth_chain_difference(state_sets)
|
return self.store.get_auth_chain_difference(room_id, state_sets)
|
||||||
|
@ -97,7 +97,9 @@ async def resolve_events_with_store(
|
|||||||
|
|
||||||
# Also fetch all auth events that appear in only some of the state sets'
|
# Also fetch all auth events that appear in only some of the state sets'
|
||||||
# auth chains.
|
# auth chains.
|
||||||
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
|
auth_diff = await _get_auth_chain_difference(
|
||||||
|
room_id, state_sets, event_map, state_res_store
|
||||||
|
)
|
||||||
|
|
||||||
full_conflicted_set = set(
|
full_conflicted_set = set(
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
|
|||||||
|
|
||||||
|
|
||||||
async def _get_auth_chain_difference(
|
async def _get_auth_chain_difference(
|
||||||
|
room_id: str,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[StateMap[str]],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: "synapse.state.StateResolutionStore",
|
||||||
@ -332,7 +335,9 @@ async def _get_auth_chain_difference(
|
|||||||
difference_from_event_map = ()
|
difference_from_event_map = ()
|
||||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||||
|
|
||||||
difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
|
difference = await state_res_store.get_auth_chain_difference(
|
||||||
|
room_id, state_sets_ids
|
||||||
|
)
|
||||||
difference.update(difference_from_event_map)
|
difference.update(difference_from_event_map)
|
||||||
|
|
||||||
return difference
|
return difference
|
||||||
|
@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
|
|
||||||
return list(results)
|
return list(results)
|
||||||
|
|
||||||
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
|
async def get_auth_chain_difference(
|
||||||
|
self, room_id: str, state_sets: List[Set[str]]
|
||||||
|
) -> Set[str]:
|
||||||
"""Given sets of state events figure out the auth chain difference (as
|
"""Given sets of state events figure out the auth chain difference (as
|
||||||
per state res v2 algorithm).
|
per state res v2 algorithm).
|
||||||
|
|
||||||
|
@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
store = TestStateResolutionStore(persisted_events)
|
store = TestStateResolutionStore(persisted_events)
|
||||||
|
|
||||||
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
|
diff_d = _get_auth_chain_difference(
|
||||||
|
ROOM_ID, state_sets, unpersited_events, store
|
||||||
|
)
|
||||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||||
|
|
||||||
self.assertEqual(difference, {c.event_id})
|
self.assertEqual(difference, {c.event_id})
|
||||||
@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
store = TestStateResolutionStore(persisted_events)
|
store = TestStateResolutionStore(persisted_events)
|
||||||
|
|
||||||
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
|
diff_d = _get_auth_chain_difference(
|
||||||
|
ROOM_ID, state_sets, unpersited_events, store
|
||||||
|
)
|
||||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||||
|
|
||||||
self.assertEqual(difference, {d.event_id, c.event_id})
|
self.assertEqual(difference, {d.event_id, c.event_id})
|
||||||
@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
store = TestStateResolutionStore(persisted_events)
|
store = TestStateResolutionStore(persisted_events)
|
||||||
|
|
||||||
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
|
diff_d = _get_auth_chain_difference(
|
||||||
|
ROOM_ID, state_sets, unpersited_events, store
|
||||||
|
)
|
||||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||||
|
|
||||||
self.assertEqual(difference, {d.event_id, e.event_id})
|
self.assertEqual(difference, {d.event_id, e.event_id})
|
||||||
@ -773,7 +779,7 @@ class TestStateResolutionStore:
|
|||||||
|
|
||||||
return list(result)
|
return list(result)
|
||||||
|
|
||||||
def get_auth_chain_difference(self, auth_sets):
|
def get_auth_chain_difference(self, room_id, auth_sets):
|
||||||
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
||||||
|
|
||||||
common = set(chains[0]).intersection(*chains[1:])
|
common = set(chains[0]).intersection(*chains[1:])
|
||||||
|
@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
# Now actually test that various combinations give the right result:
|
# Now actually test that various combinations give the right result:
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a"}, {"b"}])
|
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b"})
|
self.assertSetEqual(difference, {"a", "b"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
|
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
|
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
|
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b", "c"})
|
self.assertSetEqual(difference, {"a", "b", "c"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
|
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b"})
|
self.assertSetEqual(difference, {"a", "b"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
|
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b", "d", "e"})
|
self.assertSetEqual(difference, {"a", "b", "d", "e"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
|
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
|
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
|
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
|
||||||
)
|
)
|
||||||
self.assertSetEqual(difference, {"a", "b"})
|
self.assertSetEqual(difference, {"a", "b"})
|
||||||
|
|
||||||
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
|
difference = self.get_success(
|
||||||
|
self.store.get_auth_chain_difference(room_id, [{"a"}])
|
||||||
|
)
|
||||||
self.assertSetEqual(difference, set())
|
self.assertSetEqual(difference, set())
|
||||||
|
Loading…
Reference in New Issue
Block a user