Pull out less state when handling gaps mk2 (#12852)

This commit is contained in:
Erik Johnston 2022-05-26 10:48:12 +01:00 committed by GitHub
parent 1b338476af
commit b83bc5fab5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 236 additions and 127 deletions

View file

@ -274,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id,
)
await self._process_received_pdu(origin, pdu, state=None)
await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event(
self, origin: str, event: EventBase
@ -463,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)
@ -512,12 +514,12 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
state = await self._resolve_state_at_missing_prevs(destination, event)
state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state_ids,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@ -767,11 +769,12 @@ class FederationEventHandler:
return
try:
state = await self._resolve_state_at_missing_prevs(origin, event)
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled
origin, event, state_ids=state_ids, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
@ -781,7 +784,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@ -808,8 +811,8 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the event ids of the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
@ -829,7 +832,7 @@ class FederationEventHandler:
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: event}
try:
# Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
@ -849,40 +852,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
remote_state = await self._get_state_after_missing_prev_event(
dest, room_id, p
remote_state_map = (
await self._get_state_ids_after_missing_prev_event(
dest, room_id, p
)
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@ -894,14 +880,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map
async def _get_state_after_missing_prev_event(
async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@ -910,7 +896,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
The event ids of the state *after* the given event.
"""
(
state_event_ids,
@ -925,19 +911,17 @@ class FederationEventHandler:
len(auth_event_ids),
)
# start by just trying to fetch the events from the store
# Start by checking events we already have in the DB
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)
missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@ -948,7 +932,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@ -974,47 +958,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
# We now need to fill out the state map, which involves fetching the
# type and state key for each event ID in the state.
state_map = {}
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
#
# This can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
del fetched_events[bad_event_id]
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@ -1022,14 +1010,12 @@ class FederationEventHandler:
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state
return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@ -1056,7 +1042,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state_ids: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@ -1078,7 +1064,7 @@ class FederationEventHandler:
event: event to be persisted
state: Normally None, but if we are handling a gap in the graph
state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
@ -1090,7 +1076,8 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event,
state_ids_before_event=state_ids,
)
context = await self._check_event_auth(
origin,
@ -1107,7 +1094,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state, origin=origin)
await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1589,7 +1576,7 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state_ids: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@ -1597,7 +1584,7 @@ class FederationEventHandler:
Args:
event
state: The state at the event if we don't have all the event's prev events
state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@ -1613,7 +1600,7 @@ class FederationEventHandler:
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state".
if state is not None:
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
@ -1626,17 +1613,20 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_storage.get_state_groups(
state_sets_d = await self._state_storage.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids

View file

@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.types import (
MutableStateMap,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@ -1022,8 +1029,35 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)
state_map_for_event: MutableStateMap[str] = {}
for state_id in state_event_ids:
data = metadata.get(state_id)
if data is None:
# We're trying to persist a new historical batch of events
# with the given state, e.g. via
# `RoomBatchSendEventRestServlet`. The state can be inferred
# by Synapse or set directly by the client.
#
# Either way, we should have persisted all the state before
# getting here.
raise Exception(
f"State event {state_id} not found in DB,"
" Synapse should have persisted it before using it."
)
if data.state_key is None:
raise Exception(
f"Trying to set non-state event {state_id} as state"
)
state_map_for_event[(data.event_type, data.state_key)] = state_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
)
else:
context = await self.state.compute_event_context(event)