Use the chain cover index in get_auth_chain_ids. (#9576)

This uses a simplified version of get_chain_cover_difference to calculate
auth chain of events.
This commit is contained in:
Patrick Cloke 2021-03-10 09:57:59 -05:00 committed by GitHub
parent 918f6ed827
commit 2a99cc6524
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 226 additions and 11 deletions

View file

@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
An awaitable which resolve to a list of event_ids
list of event_ids
"""
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
self._get_auth_chain_ids_using_cover_index_txn,
room_id,
event_ids,
include_given,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs using the chain index."""
# First we look up the chain ID/sequence numbers for the given events.
initial_events = set(event_ids)
# All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str]
# A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
seen_events.add(event_id)
event_chains[chain_id] = max(
sequence_number, event_chains.get(chain_id, 0)
)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(seen_events)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int]
# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence less than that are in the
# auth chain.
if include_given:
results = initial_events
else:
results = set()
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND sequence_number <= ?
"""
for chain_id, max_no in chains.items():
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)
return list(results)
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
"""
if include_given:
results = set(event_ids)
else: