mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-07 01:54:57 -04:00
Use the chain cover index in get_auth_chain_ids. (#9576)
This uses a simplified version of get_chain_cover_difference to calculate auth chain of events.
This commit is contained in:
parent
918f6ed827
commit
2a99cc6524
5 changed files with 226 additions and 11 deletions
|
@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
) # type: LruCache[str, List[Tuple[str, int]]]
|
||||
|
||||
async def get_auth_chain(
|
||||
self, event_ids: Collection[str], include_given: bool = False
|
||||
self, room_id: str, event_ids: Collection[str], include_given: bool = False
|
||||
) -> List[EventBase]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
room_id: The room the event is in.
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
|
||||
|
@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
list of events
|
||||
"""
|
||||
event_ids = await self.get_auth_chain_ids(
|
||||
event_ids, include_given=include_given
|
||||
room_id, event_ids, include_given=include_given
|
||||
)
|
||||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
async def get_auth_chain_ids(
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
include_given: bool = False,
|
||||
) -> List[str]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
room_id: The room the event is in.
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
|
||||
Returns:
|
||||
An awaitable which resolve to a list of event_ids
|
||||
list of event_ids
|
||||
"""
|
||||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id)
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids_chains",
|
||||
self._get_auth_chain_ids_using_cover_index_txn,
|
||||
room_id,
|
||||
event_ids,
|
||||
include_given,
|
||||
)
|
||||
except _NoChainCoverIndex:
|
||||
# For whatever reason we don't actually have a chain cover index
|
||||
# for the events in question, so we fall back to the old method.
|
||||
pass
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
|
@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
include_given,
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_using_cover_index_txn(
|
||||
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
|
||||
) -> List[str]:
|
||||
"""Calculates the auth chain IDs using the chain index."""
|
||||
|
||||
# First we look up the chain ID/sequence numbers for the given events.
|
||||
|
||||
initial_events = set(event_ids)
|
||||
|
||||
# All the events that we've found that are reachable from the events.
|
||||
seen_events = set() # type: Set[str]
|
||||
|
||||
# A map from chain ID to max sequence number of the given events.
|
||||
event_chains = {} # type: Dict[int, int]
|
||||
|
||||
sql = """
|
||||
SELECT event_id, chain_id, sequence_number
|
||||
FROM event_auth_chains
|
||||
WHERE %s
|
||||
"""
|
||||
for batch in batch_iter(initial_events, 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "event_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
for event_id, chain_id, sequence_number in txn:
|
||||
seen_events.add(event_id)
|
||||
event_chains[chain_id] = max(
|
||||
sequence_number, event_chains.get(chain_id, 0)
|
||||
)
|
||||
|
||||
# Check that we actually have a chain ID for all the events.
|
||||
events_missing_chain_info = initial_events.difference(seen_events)
|
||||
if events_missing_chain_info:
|
||||
# This can happen due to e.g. downgrade/upgrade of the server. We
|
||||
# raise an exception and fall back to the previous algorithm.
|
||||
logger.info(
|
||||
"Unexpectedly found that events don't have chain IDs in room %s: %s",
|
||||
room_id,
|
||||
events_missing_chain_info,
|
||||
)
|
||||
raise _NoChainCoverIndex(room_id)
|
||||
|
||||
# Now we look up all links for the chains we have, adding chains that
|
||||
# are reachable from any event.
|
||||
sql = """
|
||||
SELECT
|
||||
origin_chain_id, origin_sequence_number,
|
||||
target_chain_id, target_sequence_number
|
||||
FROM event_auth_chain_links
|
||||
WHERE %s
|
||||
"""
|
||||
|
||||
# A map from chain ID to max sequence number *reachable* from any event ID.
|
||||
chains = {} # type: Dict[int, int]
|
||||
|
||||
# Add all linked chains reachable from initial set of chains.
|
||||
for batch in batch_iter(event_chains, 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
for (
|
||||
origin_chain_id,
|
||||
origin_sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
) in txn:
|
||||
# chains are only reachable if the origin sequence number of
|
||||
# the link is less than the max sequence number in the
|
||||
# origin chain.
|
||||
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
|
||||
chains[target_chain_id] = max(
|
||||
target_sequence_number,
|
||||
chains.get(target_chain_id, 0),
|
||||
)
|
||||
|
||||
# Add the initial set of chains, excluding the sequence corresponding to
|
||||
# initial event.
|
||||
for chain_id, seq_no in event_chains.items():
|
||||
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
|
||||
|
||||
# Now for each chain we figure out the maximum sequence number reachable
|
||||
# from *any* event ID. Events with a sequence less than that are in the
|
||||
# auth chain.
|
||||
if include_given:
|
||||
results = initial_events
|
||||
else:
|
||||
results = set()
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND sequence_number <= max_seq
|
||||
"""
|
||||
|
||||
rows = txn.execute_values(sql, chains.items())
|
||||
results.update(r for r, in rows)
|
||||
else:
|
||||
# For SQLite we just fall back to doing a noddy for loop.
|
||||
sql = """
|
||||
SELECT event_id FROM event_auth_chains
|
||||
WHERE chain_id = ? AND sequence_number <= ?
|
||||
"""
|
||||
for chain_id, max_no in chains.items():
|
||||
txn.execute(sql, (chain_id, max_no))
|
||||
results.update(r for r, in txn)
|
||||
|
||||
return list(results)
|
||||
|
||||
def _get_auth_chain_ids_txn(
|
||||
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||
) -> List[str]:
|
||||
"""Calculates the auth chain IDs.
|
||||
|
||||
This is used when we don't have a cover index for the room.
|
||||
"""
|
||||
if include_given:
|
||||
results = set(event_ids)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue