Use a chain cover index to efficiently calculate auth chain difference (#8868)

This commit is contained in:
Erik Johnston 2021-01-11 16:09:22 +00:00 committed by GitHub
parent 671138f658
commit 1315a2e8be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1777 additions and 56 deletions

View file

@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains.
"""
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
"""
# First we look up the chain ID/sequence numbers for all the events, and
# work out the chain/sequence numbers reachable from each state set.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Map from event_id -> (chain ID, seq no)
chain_info = {} # type: Dict[str, Tuple[int, int]]
# Map from chain ID -> seq no -> event Id
chain_to_event = {} # type: Dict[int, Dict[int, str]]
# All the chains that we've found that are reachable from the state
# sets.
seen_chains = set() # type: Set[int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
set_to_chain = [] # type: List[Dict[int, int]]
for state_set in state_sets:
chains = {} # type: Dict[int, int]
set_to_chain.append(chains)
for event_id in state_set:
chain_id, seq_no = chain_info[event_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
# Now we look up all links for the chains we have, adding chains to
# set_to_chain that are reachable from each set.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
for chains in set_to_chain:
# chains are only reachable if the origin sequence number of
# the link is less than the max sequence number in the
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number, chains.get(target_chain_id, 0),
)
seen_chains.add(target_chain_id)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
if min_seq_no < max_seq_no:
# We have a non empty gap, try and fill it from the events that
# we have, otherwise add them to the list of gaps to pull out
# from the DB.
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
if event_id:
result.add(event_id)
else:
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
break
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
return result
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
rows = txn.execute_values(sql, args)
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
"""
for chain_id, (min_no, max_no) in chain_to_gap.items():
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for r, in txn)
return result
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
This is used when we don't have a cover index for the room.
"""
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~