mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 04:34:19 -05:00
Use the chain cover index in get_auth_chain_ids. (#9576)
This uses a simplified version of get_chain_cover_difference to calculate the auth chain of events.
This commit is contained in:
parent
918f6ed827
commit
2a99cc6524
1
changelog.d/9576.misc
Normal file
1
changelog.d/9576.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve efficiency of calculating the auth chain in large rooms.
|
@ -447,7 +447,7 @@ class FederationServer(FederationBase):
|
|||||||
|
|
||||||
async def _on_state_ids_request_compute(self, room_id, event_id):
|
async def _on_state_ids_request_compute(self, room_id, event_id):
|
||||||
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||||
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
|
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
|
||||||
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||||
|
|
||||||
async def _on_context_state_request_compute(
|
async def _on_context_state_request_compute(
|
||||||
@ -460,7 +460,9 @@ class FederationServer(FederationBase):
|
|||||||
else:
|
else:
|
||||||
pdus = (await self.state.get_current_state(room_id)).values()
|
pdus = (await self.state.get_current_state(room_id)).values()
|
||||||
|
|
||||||
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
|
auth_chain = await self.store.get_auth_chain(
|
||||||
|
room_id, [pdu.event_id for pdu in pdus]
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||||
|
@ -1317,7 +1317,7 @@ class FederationHandler(BaseHandler):
|
|||||||
async def on_event_auth(self, event_id: str) -> List[EventBase]:
|
async def on_event_auth(self, event_id: str) -> List[EventBase]:
|
||||||
event = await self.store.get_event(event_id)
|
event = await self.store.get_event(event_id)
|
||||||
auth = await self.store.get_auth_chain(
|
auth = await self.store.get_auth_chain(
|
||||||
list(event.auth_event_ids()), include_given=True
|
event.room_id, list(event.auth_event_ids()), include_given=True
|
||||||
)
|
)
|
||||||
return list(auth)
|
return list(auth)
|
||||||
|
|
||||||
@ -1580,7 +1580,7 @@ class FederationHandler(BaseHandler):
|
|||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
|
|
||||||
state_ids = list(prev_state_ids.values())
|
state_ids = list(prev_state_ids.values())
|
||||||
auth_chain = await self.store.get_auth_chain(state_ids)
|
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
|
||||||
|
|
||||||
state = await self.store.get_events(list(prev_state_ids.values()))
|
state = await self.store.get_events(list(prev_state_ids.values()))
|
||||||
|
|
||||||
@ -2219,7 +2219,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
# Now get the current auth_chain for the event.
|
# Now get the current auth_chain for the event.
|
||||||
local_auth_chain = await self.store.get_auth_chain(
|
local_auth_chain = await self.store.get_auth_chain(
|
||||||
list(event.auth_event_ids()), include_given=True
|
room_id, list(event.auth_event_ids()), include_given=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||||
|
@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
) # type: LruCache[str, List[Tuple[str, int]]]
|
) # type: LruCache[str, List[Tuple[str, int]]]
|
||||||
|
|
||||||
async def get_auth_chain(
|
async def get_auth_chain(
|
||||||
self, event_ids: Collection[str], include_given: bool = False
|
self, room_id: str, event_ids: Collection[str], include_given: bool = False
|
||||||
) -> List[EventBase]:
|
) -> List[EventBase]:
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
room_id: The room the event is in.
|
||||||
event_ids: state events
|
event_ids: state events
|
||||||
include_given: include the given events in result
|
include_given: include the given events in result
|
||||||
|
|
||||||
@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
list of events
|
list of events
|
||||||
"""
|
"""
|
||||||
event_ids = await self.get_auth_chain_ids(
|
event_ids = await self.get_auth_chain_ids(
|
||||||
event_ids, include_given=include_given
|
room_id, event_ids, include_given=include_given
|
||||||
)
|
)
|
||||||
return await self.get_events_as_list(event_ids)
|
return await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
async def get_auth_chain_ids(
|
async def get_auth_chain_ids(
|
||||||
self,
|
self,
|
||||||
|
room_id: str,
|
||||||
event_ids: Collection[str],
|
event_ids: Collection[str],
|
||||||
include_given: bool = False,
|
include_given: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
room_id: The room the event is in.
|
||||||
event_ids: state events
|
event_ids: state events
|
||||||
include_given: include the given events in result
|
include_given: include the given events in result
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An awaitable which resolve to a list of event_ids
|
list of event_ids
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Check if we have indexed the room so we can use the chain cover
|
||||||
|
# algorithm.
|
||||||
|
room = await self.get_room(room_id)
|
||||||
|
if room["has_auth_chain_index"]:
|
||||||
|
try:
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_auth_chain_ids_chains",
|
||||||
|
self._get_auth_chain_ids_using_cover_index_txn,
|
||||||
|
room_id,
|
||||||
|
event_ids,
|
||||||
|
include_given,
|
||||||
|
)
|
||||||
|
except _NoChainCoverIndex:
|
||||||
|
# For whatever reason we don't actually have a chain cover index
|
||||||
|
# for the events in question, so we fall back to the old method.
|
||||||
|
pass
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_auth_chain_ids",
|
"get_auth_chain_ids",
|
||||||
self._get_auth_chain_ids_txn,
|
self._get_auth_chain_ids_txn,
|
||||||
@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
include_given,
|
include_given,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_auth_chain_ids_using_cover_index_txn(
|
||||||
|
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
|
||||||
|
) -> List[str]:
|
||||||
|
"""Calculates the auth chain IDs using the chain index."""
|
||||||
|
|
||||||
|
# First we look up the chain ID/sequence numbers for the given events.
|
||||||
|
|
||||||
|
initial_events = set(event_ids)
|
||||||
|
|
||||||
|
# All the events that we've found that are reachable from the events.
|
||||||
|
seen_events = set() # type: Set[str]
|
||||||
|
|
||||||
|
# A map from chain ID to max sequence number of the given events.
|
||||||
|
event_chains = {} # type: Dict[int, int]
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT event_id, chain_id, sequence_number
|
||||||
|
FROM event_auth_chains
|
||||||
|
WHERE %s
|
||||||
|
"""
|
||||||
|
for batch in batch_iter(initial_events, 1000):
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "event_id", batch
|
||||||
|
)
|
||||||
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
|
for event_id, chain_id, sequence_number in txn:
|
||||||
|
seen_events.add(event_id)
|
||||||
|
event_chains[chain_id] = max(
|
||||||
|
sequence_number, event_chains.get(chain_id, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that we actually have a chain ID for all the events.
|
||||||
|
events_missing_chain_info = initial_events.difference(seen_events)
|
||||||
|
if events_missing_chain_info:
|
||||||
|
# This can happen due to e.g. downgrade/upgrade of the server. We
|
||||||
|
# raise an exception and fall back to the previous algorithm.
|
||||||
|
logger.info(
|
||||||
|
"Unexpectedly found that events don't have chain IDs in room %s: %s",
|
||||||
|
room_id,
|
||||||
|
events_missing_chain_info,
|
||||||
|
)
|
||||||
|
raise _NoChainCoverIndex(room_id)
|
||||||
|
|
||||||
|
# Now we look up all links for the chains we have, adding chains that
|
||||||
|
# are reachable from any event.
|
||||||
|
sql = """
|
||||||
|
SELECT
|
||||||
|
origin_chain_id, origin_sequence_number,
|
||||||
|
target_chain_id, target_sequence_number
|
||||||
|
FROM event_auth_chain_links
|
||||||
|
WHERE %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
# A map from chain ID to max sequence number *reachable* from any event ID.
|
||||||
|
chains = {} # type: Dict[int, int]
|
||||||
|
|
||||||
|
# Add all linked chains reachable from initial set of chains.
|
||||||
|
for batch in batch_iter(event_chains, 1000):
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "origin_chain_id", batch
|
||||||
|
)
|
||||||
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
|
for (
|
||||||
|
origin_chain_id,
|
||||||
|
origin_sequence_number,
|
||||||
|
target_chain_id,
|
||||||
|
target_sequence_number,
|
||||||
|
) in txn:
|
||||||
|
# chains are only reachable if the origin sequence number of
|
||||||
|
# the link is less than or equal to the max sequence number in the
|
||||||
|
# origin chain.
|
||||||
|
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
|
||||||
|
chains[target_chain_id] = max(
|
||||||
|
target_sequence_number,
|
||||||
|
chains.get(target_chain_id, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the initial set of chains, excluding the sequence corresponding to
|
||||||
|
# initial event.
|
||||||
|
for chain_id, seq_no in event_chains.items():
|
||||||
|
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
|
||||||
|
|
||||||
|
# Now for each chain we figure out the maximum sequence number reachable
|
||||||
|
# from *any* event ID. Events with a sequence number less than or equal to that are in the
|
||||||
|
# auth chain.
|
||||||
|
if include_given:
|
||||||
|
results = initial_events
|
||||||
|
else:
|
||||||
|
results = set()
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# We can use `execute_values` to efficiently fetch the gaps when
|
||||||
|
# using postgres.
|
||||||
|
sql = """
|
||||||
|
SELECT event_id
|
||||||
|
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
||||||
|
WHERE
|
||||||
|
c.chain_id = l.chain_id
|
||||||
|
AND sequence_number <= max_seq
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows = txn.execute_values(sql, chains.items())
|
||||||
|
results.update(r for r, in rows)
|
||||||
|
else:
|
||||||
|
# For SQLite we just fall back to doing a noddy for loop.
|
||||||
|
sql = """
|
||||||
|
SELECT event_id FROM event_auth_chains
|
||||||
|
WHERE chain_id = ? AND sequence_number <= ?
|
||||||
|
"""
|
||||||
|
for chain_id, max_no in chains.items():
|
||||||
|
txn.execute(sql, (chain_id, max_no))
|
||||||
|
results.update(r for r, in txn)
|
||||||
|
|
||||||
|
return list(results)
|
||||||
|
|
||||||
def _get_auth_chain_ids_txn(
|
def _get_auth_chain_ids_txn(
|
||||||
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
"""Calculates the auth chain IDs.
|
||||||
|
|
||||||
|
This is used when we don't have a cover index for the room.
|
||||||
|
"""
|
||||||
if include_given:
|
if include_given:
|
||||||
results = set(event_ids)
|
results = set(event_ids)
|
||||||
else:
|
else:
|
||||||
|
@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
|
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
|
||||||
self.assertTrue(r == [room2] or r == [room3])
|
self.assertTrue(r == [room2] or r == [room3])
|
||||||
|
|
||||||
@parameterized.expand([(True,), (False,)])
|
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
|
||||||
def test_auth_difference(self, use_chain_cover_index: bool):
|
|
||||||
room_id = "@ROOM:local"
|
room_id = "@ROOM:local"
|
||||||
|
|
||||||
# The silly auth graph we use to test the auth difference algorithm,
|
# The silly auth graph we use to test the auth difference algorithm,
|
||||||
@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
"j": 1,
|
"j": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mark the room as not having a cover index
|
# Mark the room as maybe having a cover index.
|
||||||
|
|
||||||
def store_room(txn):
|
def store_room(txn):
|
||||||
self.store.db_pool.simple_insert_txn(
|
self.store.db_pool.simple_insert_txn(
|
||||||
@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return room_id
|
||||||
|
|
||||||
|
@parameterized.expand([(True,), (False,)])
|
||||||
|
def test_auth_chain_ids(self, use_chain_cover_index: bool):
|
||||||
|
room_id = self._setup_auth_chain(use_chain_cover_index)
|
||||||
|
|
||||||
|
# a and b have the same auth chain.
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
|
||||||
|
auth_chain_ids = self.get_success(
|
||||||
|
self.store.get_auth_chain_ids(room_id, ["a", "b"])
|
||||||
|
)
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
# d and e have the same auth chain.
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
|
||||||
|
self.assertEqual(auth_chain_ids, ["k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
|
||||||
|
self.assertEqual(auth_chain_ids, ["j"])
|
||||||
|
|
||||||
|
# j and k have no parents.
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
|
||||||
|
self.assertEqual(auth_chain_ids, [])
|
||||||
|
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
|
||||||
|
self.assertEqual(auth_chain_ids, [])
|
||||||
|
|
||||||
|
# More complex input sequences.
|
||||||
|
auth_chain_ids = self.get_success(
|
||||||
|
self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
|
||||||
|
)
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
auth_chain_ids = self.get_success(
|
||||||
|
self.store.get_auth_chain_ids(room_id, ["h", "i"])
|
||||||
|
)
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["k", "j"])
|
||||||
|
|
||||||
|
# e gets returned even though include_given is false, because it is in the
|
||||||
|
# auth chain of b.
|
||||||
|
auth_chain_ids = self.get_success(
|
||||||
|
self.store.get_auth_chain_ids(room_id, ["b", "e"])
|
||||||
|
)
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
|
||||||
|
|
||||||
|
# Test include_given.
|
||||||
|
auth_chain_ids = self.get_success(
|
||||||
|
self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
|
||||||
|
)
|
||||||
|
self.assertCountEqual(auth_chain_ids, ["i", "j"])
|
||||||
|
|
||||||
|
@parameterized.expand([(True,), (False,)])
|
||||||
|
def test_auth_difference(self, use_chain_cover_index: bool):
|
||||||
|
room_id = self._setup_auth_chain(use_chain_cover_index)
|
||||||
|
|
||||||
# Now actually test that various combinations give the right result:
|
# Now actually test that various combinations give the right result:
|
||||||
|
|
||||||
difference = self.get_success(
|
difference = self.get_success(
|
||||||
|
Loading…
Reference in New Issue
Block a user