Optimise backfill calculation (#12522)

Try to avoid an OOM by checking fewer extremities.

Generally this is a big rewrite of _maybe_backfill, to try and fix some of the TODOs and other problems in it. It's best reviewed commit-by-commit.
This commit is contained in:
Richard van der Hoff 2022-04-26 10:27:11 +01:00 committed by GitHub
parent e75c7e3b6d
commit 17d99f758a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 168 additions and 106 deletions

1
changelog.d/12522.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 0.99.3 which could cause Synapse to consume large amounts of RAM when back-paginating in a large room.

View File

@ -1,4 +1,4 @@
# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome # Copyright 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -15,10 +15,14 @@
"""Contains handlers for federation events.""" """Contains handlers for federation events."""
import enum
import itertools
import logging import logging
from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -92,6 +96,24 @@ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
return sorted(joined_domains.items(), key=lambda d: d[1]) return sorted(joined_domains.items(), key=lambda d: d[1])
class _BackfillPointType(Enum):
# a regular backwards extremity (ie, an event which we don't yet have, but which
# is referred to by other events in the DAG)
BACKWARDS_EXTREMITY = enum.auto()
# an MSC2716 "insertion event"
INSERTION_PONT = enum.auto()
@attr.s(slots=True, auto_attribs=True, frozen=True)
class _BackfillPoint:
"""A potential point we might backfill from"""
event_id: str
depth: int
type: _BackfillPointType
class FederationHandler: class FederationHandler:
"""Handles general incoming federation requests """Handles general incoming federation requests
@ -157,89 +179,51 @@ class FederationHandler:
async def _maybe_backfill_inner( async def _maybe_backfill_inner(
self, room_id: str, current_depth: int, limit: int self, room_id: str, current_depth: int, limit: int
) -> bool: ) -> bool:
oldest_events_with_depth = ( backwards_extremities = [
await self.store.get_oldest_event_ids_with_depth_in_room(room_id) _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
) for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
insertion_events_to_be_backfilled: Dict[str, int] = {}
if self.hs.config.experimental.msc2716_enabled:
insertion_events_to_be_backfilled = (
await self.store.get_insertion_event_backward_extremities_in_room(
room_id room_id
) )
]
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
if self.hs.config.experimental.msc2716_enabled:
insertion_events_to_be_backfilled = [
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
room_id
) )
]
logger.debug( logger.debug(
"_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", "_maybe_backfill_inner: backwards_extremities=%s insertion_events_to_be_backfilled=%s",
oldest_events_with_depth, backwards_extremities,
insertion_events_to_be_backfilled, insertion_events_to_be_backfilled,
) )
if not oldest_events_with_depth and not insertion_events_to_be_backfilled: if not backwards_extremities and not insertion_events_to_be_backfilled:
logger.debug("Not backfilling as no extremeties found.") logger.debug("Not backfilling as no extremeties found.")
return False return False
# We only want to paginate if we can actually see the events we'll get, # we now have a list of potential places to backpaginate from. We prefer to
# as otherwise we'll just spend a lot of resources to get redacted # start with the most recent (ie, max depth), so let's sort the list.
# events. sorted_backfill_points: List[_BackfillPoint] = sorted(
# itertools.chain(
# We do this by filtering all the backwards extremities and seeing if backwards_extremities,
# any remain. Given we don't have the extremity events themselves, we insertion_events_to_be_backfilled,
# need to actually check the events that reference them. ),
# key=lambda e: -int(e.depth),
# *Note*: the spec wants us to keep backfilling until we reach the start
# of the room in case we are allowed to see some of the history. However
# in practice that causes more issues than its worth, as a) its
# relatively rare for there to be any visible history and b) even when
# there is its often sufficiently long ago that clients would stop
# attempting to paginate before backfill reached the visible history.
#
# TODO: If we do do a backfill then we should filter the backwards
# extremities to only include those that point to visible portions of
# history.
#
# TODO: Correctly handle the case where we are allowed to see the
# forward event but not the backward extremity, e.g. in the case of
# initial join of the server where we are allowed to see the join
# event but not anything before it. This would require looking at the
# state *before* the event, ignoring the special casing certain event
# types have.
forward_event_ids = await self.store.get_successor_events(
list(oldest_events_with_depth)
) )
extremities_events = await self.store.get_events(
forward_event_ids,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
redact=False,
check_history_visibility_only=True,
)
logger.debug( logger.debug(
"_maybe_backfill_inner: filtered_extremities %s", filtered_extremities "_maybe_backfill_inner: room_id: %s: current_depth: %s, limit: %s, "
"backfill points (%d): %s",
room_id,
current_depth,
limit,
len(sorted_backfill_points),
sorted_backfill_points,
) )
if not filtered_extremities and not insertion_events_to_be_backfilled:
return False
extremities = {
**oldest_events_with_depth,
# TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks
**insertion_events_to_be_backfilled,
}
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
# If we're approaching an extremity we trigger a backfill, otherwise we # If we're approaching an extremity we trigger a backfill, otherwise we
# no-op. # no-op.
# #
@ -249,6 +233,11 @@ class FederationHandler:
# chose more than one times the limit in case of failure, but choosing a # chose more than one times the limit in case of failure, but choosing a
# much larger factor will result in triggering a backfill request much # much larger factor will result in triggering a backfill request much
# earlier than necessary. # earlier than necessary.
#
# XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
# care about events that have happened after our current position.
#
max_depth = sorted_backfill_points[0].depth
if current_depth - 2 * limit > max_depth: if current_depth - 2 * limit > max_depth:
logger.debug( logger.debug(
"Not backfilling as we don't need to. %d < %d - 2 * %d", "Not backfilling as we don't need to. %d < %d - 2 * %d",
@ -265,31 +254,98 @@ class FederationHandler:
# 2. we have likely previously tried and failed to backfill from that # 2. we have likely previously tried and failed to backfill from that
# extremity, so to avoid getting "stuck" requesting the same # extremity, so to avoid getting "stuck" requesting the same
# backfill repeatedly we drop those extremities. # backfill repeatedly we drop those extremities.
filtered_sorted_extremeties_tuple = [ #
t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
]
logger.debug(
"room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems (%d): %s filtered_sorted_extremeties_tuple: %s",
room_id,
current_depth,
limit,
max_depth,
len(sorted_extremeties_tuple),
sorted_extremeties_tuple,
filtered_sorted_extremeties_tuple,
)
# However, we need to check that the filtered extremities are non-empty. # However, we need to check that the filtered extremities are non-empty.
# If they are empty then either we can a) bail or b) still attempt to # If they are empty then either we can a) bail or b) still attempt to
# backfill. We opt to try backfilling anyway just in case we do get # backfill. We opt to try backfilling anyway just in case we do get
# relevant events. # relevant events.
if filtered_sorted_extremeties_tuple: #
sorted_extremeties_tuple = filtered_sorted_extremeties_tuple filtered_sorted_backfill_points = [
t for t in sorted_backfill_points if t.depth <= current_depth
]
if filtered_sorted_backfill_points:
logger.debug(
"_maybe_backfill_inner: backfill points before current depth: %s",
filtered_sorted_backfill_points,
)
sorted_backfill_points = filtered_sorted_backfill_points
else:
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
)
# We don't want to specify too many extremities as it causes the backfill # For performance's sake, we only want to paginate from a particular extremity
# request URI to be too long. # if we can actually see the events we'll get. Otherwise, we'd just spend a lot
extremities = dict(sorted_extremeties_tuple[:5]) # of resources to get redacted events. We check each extremity in turn and
# ignore those which users on our server wouldn't be able to see.
#
# Additionally, we limit ourselves to backfilling from at most 5 extremities,
# for two reasons:
#
# - The check which determines if we can see an extremity's events can be
# expensive (we load the full state for the room at each of the backfill
# points, or (worse) their successors)
# - We want to avoid the server-server API request URI becoming too long.
#
# *Note*: the spec wants us to keep backfilling until we reach the start
# of the room in case we are allowed to see some of the history. However,
# in practice that causes more issues than its worth, as (a) it's
# relatively rare for there to be any visible history and (b) even when
# there is it's often sufficiently long ago that clients would stop
# attempting to paginate before backfill reached the visible history.
extremities_to_request: List[str] = []
for bp in sorted_backfill_points:
if len(extremities_to_request) >= 5:
break
# For regular backwards extremities, we don't have the extremity events
# themselves, so we need to actually check the events that reference them -
# their "successor" events.
#
# TODO: Correctly handle the case where we are allowed to see the
# successor event but not the backward extremity, e.g. in the case of
# initial join of the server where we are allowed to see the join
# event but not anything before it. This would require looking at the
# state *before* the event, ignoring the special casing certain event
# types have.
if bp.type == _BackfillPointType.INSERTION_PONT:
event_ids_to_check = [bp.event_id]
else:
event_ids_to_check = await self.store.get_successor_events(bp.event_id)
events_to_check = await self.store.get_events_as_list(
event_ids_to_check,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
events_to_check,
redact=False,
check_history_visibility_only=True,
)
if filtered_extremities:
extremities_to_request.append(bp.event_id)
else:
logger.debug(
"_maybe_backfill_inner: skipping extremity %s as it would not be visible",
bp,
)
if not extremities_to_request:
logger.debug(
"_maybe_backfill_inner: found no extremities which would be visible"
)
return False
logger.debug(
"_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
)
# Now we need to decide which hosts to hit first. # Now we need to decide which hosts to hit first.
@ -309,7 +365,7 @@ class FederationHandler:
for dom in domains: for dom in domains:
try: try:
await self._federation_event_handler.backfill( await self._federation_event_handler.backfill(
dom, room_id, limit=100, extremities=extremities dom, room_id, limit=100, extremities=extremities_to_request
) )
# If this succeeded then we probably already have the # If this succeeded then we probably already have the
# appropriate stuff. # appropriate stuff.

View File

@ -54,7 +54,7 @@ class RoomBatchHandler:
# it has a larger `depth` but before the successor event because the `stream_ordering` # it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event. # is negative before the successor event.
successor_event_ids = await self.store.get_successor_events( successor_event_ids = await self.store.get_successor_events(
[most_recent_prev_event_id] most_recent_prev_event_id
) )
# If we can't find any successor events, then it's a forward extremity of # If we can't find any successor events, then it's a forward extremity of

View File

@ -695,7 +695,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them. # Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n} return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: async def get_oldest_event_ids_with_depth_in_room(
self, room_id
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the """Gets the oldest events(backwards extremities) in the room along with the
aproximate depth. aproximate depth.
@ -708,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id: Room where we want to find the oldest events room_id: Room where we want to find the oldest events
Returns: Returns:
Map from event_id to depth List of (event_id, depth) tuples
""" """
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
@ -741,7 +743,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False)) txn.execute(sql, (room_id, False))
return dict(txn) return txn.fetchall()
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room", "get_oldest_event_ids_with_depth_in_room",
@ -751,7 +753,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
async def get_insertion_event_backward_extremities_in_room( async def get_insertion_event_backward_extremities_in_room(
self, room_id self, room_id
) -> Dict[str, int]: ) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet. """Get the insertion events we know about that we haven't backfilled yet.
We use this function so that we can compare and see if someones current We use this function so that we can compare and see if someones current
@ -763,7 +765,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id: Room where we want to find the oldest events room_id: Room where we want to find the oldest events
Returns: Returns:
Map from event_id to depth List of (event_id, depth) tuples
""" """
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
@ -778,8 +780,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
""" """
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
return txn.fetchall()
return dict(txn)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room", "get_insertion_event_backward_extremities_in_room",
@ -1295,22 +1296,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse() event_results.reverse()
return event_results return event_results
async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given events as a prev event """Fetch all events that have the given event as a prev event
Args: Args:
event_ids: The events to use as the previous events. event_id: The event to search for as a prev_event.
""" """
rows = await self.db_pool.simple_select_many_batch( return await self.db_pool.simple_select_onecol(
table="event_edges", table="event_edges",
column="prev_event_id", keyvalues={"prev_event_id": event_id},
iterable=event_ids, retcol="event_id",
retcols=("event_id",),
desc="get_successor_events", desc="get_successor_events",
) )
return [row["event_id"] for row in rows]
@wrap_as_background_process("delete_old_forward_extrem_cache") @wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None: async def _delete_old_forward_extrem_cache(self) -> None:
def _delete_old_forward_extrem_cache_txn(txn): def _delete_old_forward_extrem_cache_txn(txn):

View File

@ -419,6 +419,13 @@ async def _event_to_memberships(
return {} return {}
# for each event, get the event_ids of the membership state at those events. # for each event, get the event_ids of the membership state at those events.
#
# TODO: this means that we request the entire membership list. If there are only
# one or two users on this server, and the room is huge, this is very wasteful
# (it means more db work, and churns the *stateGroupMembersCache*).
# It might be that we could extend StateFilter to specify "give me keys matching
# *:<server_name>", to avoid this.
event_to_state_ids = await storage.state.get_state_ids_for_events( event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)), state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)),