mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-23 07:54:07 -04:00
Use a chain cover index to efficiently calculate auth chain difference (#8868)
This commit is contained in:
parent
671138f658
commit
1315a2e8be
14 changed files with 1777 additions and 56 deletions
|
@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
|||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import Collection
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _NoChainCoverIndex(Exception):
|
||||
def __init__(self, room_id: str):
|
||||
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
|
||||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
The set of the difference in auth chains.
|
||||
"""
|
||||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id)
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference_chains",
|
||||
self._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
state_sets,
|
||||
)
|
||||
except _NoChainCoverIndex:
|
||||
# For whatever reason we don't actually have a chain cover index
|
||||
# for the events in question, so we fall back to the old method.
|
||||
pass
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference",
|
||||
self._get_auth_chain_difference_txn,
|
||||
state_sets,
|
||||
)
|
||||
|
||||
def _get_auth_chain_difference_using_cover_index_txn(
|
||||
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using the chain index.
|
||||
|
||||
See docs/auth_chain_difference_algorithm.md for details
|
||||
"""
|
||||
|
||||
# First we look up the chain ID/sequence numbers for all the events, and
|
||||
# work out the chain/sequence numbers reachable from each state set.
|
||||
|
||||
initial_events = set(state_sets[0]).union(*state_sets[1:])
|
||||
|
||||
# Map from event_id -> (chain ID, seq no)
|
||||
chain_info = {} # type: Dict[str, Tuple[int, int]]
|
||||
|
||||
# Map from chain ID -> seq no -> event Id
|
||||
chain_to_event = {} # type: Dict[int, Dict[int, str]]
|
||||
|
||||
# All the chains that we've found that are reachable from the state
|
||||
# sets.
|
||||
seen_chains = set() # type: Set[int]
|
||||
|
||||
sql = """
|
||||
SELECT event_id, chain_id, sequence_number
|
||||
FROM event_auth_chains
|
||||
WHERE %s
|
||||
"""
|
||||
for batch in batch_iter(initial_events, 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "event_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
for event_id, chain_id, sequence_number in txn:
|
||||
chain_info[event_id] = (chain_id, sequence_number)
|
||||
seen_chains.add(chain_id)
|
||||
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
|
||||
|
||||
# Check that we actually have a chain ID for all the events.
|
||||
events_missing_chain_info = initial_events.difference(chain_info)
|
||||
if events_missing_chain_info:
|
||||
# This can happen due to e.g. downgrade/upgrade of the server. We
|
||||
# raise an exception and fall back to the previous algorithm.
|
||||
logger.info(
|
||||
"Unexpectedly found that events don't have chain IDs in room %s: %s",
|
||||
room_id,
|
||||
events_missing_chain_info,
|
||||
)
|
||||
raise _NoChainCoverIndex(room_id)
|
||||
|
||||
# Corresponds to `state_sets`, except as a map from chain ID to max
|
||||
# sequence number reachable from the state set.
|
||||
set_to_chain = [] # type: List[Dict[int, int]]
|
||||
for state_set in state_sets:
|
||||
chains = {} # type: Dict[int, int]
|
||||
set_to_chain.append(chains)
|
||||
|
||||
for event_id in state_set:
|
||||
chain_id, seq_no = chain_info[event_id]
|
||||
|
||||
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
|
||||
|
||||
# Now we look up all links for the chains we have, adding chains to
|
||||
# set_to_chain that are reachable from each set.
|
||||
sql = """
|
||||
SELECT
|
||||
origin_chain_id, origin_sequence_number,
|
||||
target_chain_id, target_sequence_number
|
||||
FROM event_auth_chain_links
|
||||
WHERE %s
|
||||
"""
|
||||
|
||||
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
||||
# the loop)
|
||||
for batch in batch_iter(set(seen_chains), 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
for (
|
||||
origin_chain_id,
|
||||
origin_sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
) in txn:
|
||||
for chains in set_to_chain:
|
||||
# chains are only reachable if the origin sequence number of
|
||||
# the link is less than the max sequence number in the
|
||||
# origin chain.
|
||||
if origin_sequence_number <= chains.get(origin_chain_id, 0):
|
||||
chains[target_chain_id] = max(
|
||||
target_sequence_number, chains.get(target_chain_id, 0),
|
||||
)
|
||||
|
||||
seen_chains.add(target_chain_id)
|
||||
|
||||
# Now for each chain we figure out the maximum sequence number reachable
|
||||
# from *any* state set and the minimum sequence number reachable from
|
||||
# *all* state sets. Events in that range are in the auth chain
|
||||
# difference.
|
||||
result = set()
|
||||
|
||||
# Mapping from chain ID to the range of sequence numbers that should be
|
||||
# pulled from the database.
|
||||
chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
|
||||
|
||||
for chain_id in seen_chains:
|
||||
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
|
||||
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
|
||||
|
||||
if min_seq_no < max_seq_no:
|
||||
# We have a non empty gap, try and fill it from the events that
|
||||
# we have, otherwise add them to the list of gaps to pull out
|
||||
# from the DB.
|
||||
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
|
||||
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
|
||||
if event_id:
|
||||
result.add(event_id)
|
||||
else:
|
||||
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
|
||||
break
|
||||
|
||||
if not chain_to_gap:
|
||||
# If there are no gaps to fetch, we're done!
|
||||
return result
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
|
||||
WHERE
|
||||
c.chain_id = l.chain_id
|
||||
AND min_seq < sequence_number AND sequence_number <= max_seq
|
||||
"""
|
||||
|
||||
args = [
|
||||
(chain_id, min_no, max_no)
|
||||
for chain_id, (min_no, max_no) in chain_to_gap.items()
|
||||
]
|
||||
|
||||
rows = txn.execute_values(sql, args)
|
||||
result.update(r for r, in rows)
|
||||
else:
|
||||
# For SQLite we just fall back to doing a noddy for loop.
|
||||
sql = """
|
||||
SELECT event_id FROM event_auth_chains
|
||||
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
|
||||
"""
|
||||
for chain_id, (min_no, max_no) in chain_to_gap.items():
|
||||
txn.execute(sql, (chain_id, min_no, max_no))
|
||||
result.update(r for r, in txn)
|
||||
|
||||
return result
|
||||
|
||||
def _get_auth_chain_difference_txn(
|
||||
self, txn, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using a breadth first search.
|
||||
|
||||
This is used when we don't have a cover index for the room.
|
||||
"""
|
||||
|
||||
# Algorithm Description
|
||||
# ~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue