Pull out less state when handling gaps mk2 (#12852)

This commit is contained in:
Erik Johnston 2022-05-26 10:48:12 +01:00 committed by GitHub
parent 1b338476af
commit b83bc5fab5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 236 additions and 127 deletions

1
changelog.d/12852.misc Normal file
View File

@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.

View File

@ -274,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id, affected=pdu.event_id,
) )
await self._process_received_pdu(origin, pdu, state=None) await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event( async def on_send_membership_event(
self, origin: str, event: EventBase self, origin: str, event: EventBase
@ -463,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context( context = await self._state_handler.compute_event_context(
event, event,
old_state=state, state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state, partial_state=partial_state,
) )
@ -512,12 +514,12 @@ class FederationEventHandler:
# #
# This is the same operation as we do when we receive a regular event # This is the same operation as we do when we receive a regular event
# over federation. # over federation.
state = await self._resolve_state_at_missing_prevs(destination, event) state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be # build a new state group for it if need be
context = await self._state_handler.compute_event_context( context = await self._state_handler.compute_event_context(
event, event,
old_state=state, state_ids_before_event=state_ids,
) )
if context.partial_state: if context.partial_state:
# this can happen if some or all of the event's prev_events still have # this can happen if some or all of the event's prev_events still have
@ -767,11 +769,12 @@ class FederationEventHandler:
return return
try: try:
state = await self._resolve_state_at_missing_prevs(origin, event) state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state # not return partial state
await self._process_received_pdu( await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled origin, event, state_ids=state_ids, backfilled=backfilled
) )
except FederationError as e: except FederationError as e:
if e.code == 403: if e.code == 403:
@ -781,7 +784,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs( async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]: ) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events. """Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and This is used when we have pulled a batch of events from a remote server, and
@ -808,8 +811,8 @@ class FederationEventHandler:
event: an event to check for missing prevs. event: an event to check for missing prevs.
Returns: Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of if we already had all the prev events, `None`. Otherwise, returns
the events in the state at `event`. the event ids of the state at `event`.
""" """
room_id = event.room_id room_id = event.room_id
event_id = event.event_id event_id = event.event_id
@ -829,7 +832,7 @@ class FederationEventHandler:
) )
# Calculate the state after each of the previous events, and # Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event. # resolve them to find the correct state at the current event.
event_map = {event_id: event}
try: try:
# Get the state of the events we know about # Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen) ours = await self._state_storage.get_state_groups_ids(room_id, seen)
@ -849,40 +852,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or # note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped # auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client. # by the get_pdu_cache in federation_client.
remote_state = await self._get_state_after_missing_prev_event( remote_state_map = (
await self._get_state_ids_after_missing_prev_event(
dest, room_id, p dest, room_id, p
) )
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map) state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
room_version = await self._store.get_room_version_id(room_id) room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store( state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id, room_id,
room_version, room_version,
state_maps, state_maps,
event_map, event_map={event_id: event},
state_res_store=StateResolutionStore(self._store), state_res_store=StateResolutionStore(self._store),
) )
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception: except Exception:
logger.warning( logger.warning(
"Error attempting to resolve state at missing prev_events", "Error attempting to resolve state at missing prev_events",
@ -894,14 +880,14 @@ class FederationEventHandler:
"We can't get valid state history.", "We can't get valid state history.",
affected=event_id, affected=event_id,
) )
return state return state_map
async def _get_state_after_missing_prev_event( async def _get_state_ids_after_missing_prev_event(
self, self,
destination: str, destination: str,
room_id: str, room_id: str,
event_id: str, event_id: str,
) -> List[EventBase]: ) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver. """Requests all of the room state at a given event from a remote homeserver.
Args: Args:
@ -910,7 +896,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at. event_id: The id of the event we want the state at.
Returns: Returns:
A list of events in the state, including the event itself The event ids of the state *after* the given event.
""" """
( (
state_event_ids, state_event_ids,
@ -925,19 +911,17 @@ class FederationEventHandler:
len(auth_event_ids), len(auth_event_ids),
) )
# start by just trying to fetch the events from the store # Start by checking events we already have in the DB
desired_events = set(state_event_ids) desired_events = set(state_event_ids)
desired_events.add(event_id) desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events)) logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events( have_events = await self._store.have_seen_events(room_id, desired_events)
desired_events, allow_rejected=True
)
missing_desired_events = desired_events - fetched_events.keys() missing_desired_events = desired_events - have_events
logger.debug( logger.debug(
"We are missing %i events (got %i)", "We are missing %i events (got %i)",
len(missing_desired_events), len(missing_desired_events),
len(fetched_events), len(have_events),
) )
# We probably won't need most of the auth events, so let's just check which # We probably won't need most of the auth events, so let's just check which
@ -948,7 +932,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the # already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need. # federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update( missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events) await self._store.have_seen_events(room_id, missing_auth_events)
) )
@ -974,47 +958,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events destination=destination, room_id=room_id, event_ids=missing_events
) )
# we need to make sure we re-load from the database to get the rejected # We now need to fill out the state map, which involves fetching the
# state correct. # type and state key for each event ID in the state.
fetched_events.update( state_map = {}
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
# check for events which were in the wrong room. event_metadata = await self._store.get_metadata_for_events(state_event_ids)
# for state_event_id, metadata in event_metadata.items():
# this can happen if a remote server claims that the state or if metadata.room_id != room_id:
# auth_events at an event in room A are actually events in room B
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time # This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the # after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set. # bad events from the returned state set.
#
# This can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
logger.warning( logger.warning(
"Remote server %s claims event %s in room %s is an auth/state " "Remote server %s claims event %s in room %s is an auth/state "
"event in room %s", "event in room %s",
destination, destination,
bad_event_id, state_event_id,
bad_room_id, metadata.room_id,
room_id, room_id,
) )
continue
del fetched_events[bad_event_id] if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem. # if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id) remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event: if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,)) raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker # missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete # XXX: this doesn't sound right? it means that we'll end up with incomplete
# state. # state.
failed_to_fetch = desired_events - fetched_events.keys() failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch: if failed_to_fetch:
logger.warning( logger.warning(
"Failed to fetch missing state events for %s %s", "Failed to fetch missing state events for %s %s",
@ -1022,14 +1010,12 @@ class FederationEventHandler:
failed_to_fetch, failed_to_fetch,
) )
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None: if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event) state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state return state_map
async def _get_state_and_persist( async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str self, destination: str, room_id: str, event_id: str
@ -1056,7 +1042,7 @@ class FederationEventHandler:
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state_ids: Optional[StateMap[str]],
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
"""Called when we have a new non-outlier event. """Called when we have a new non-outlier event.
@ -1078,7 +1064,7 @@ class FederationEventHandler:
event: event to be persisted event: event to be persisted
state: Normally None, but if we are handling a gap in the graph state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the (ie, we are missing one or more prev_events), the resolved state at the
event event
@ -1090,7 +1076,8 @@ class FederationEventHandler:
try: try:
context = await self._state_handler.compute_event_context( context = await self._state_handler.compute_event_context(
event, old_state=state event,
state_ids_before_event=state_ids,
) )
context = await self._check_event_auth( context = await self._check_event_auth(
origin, origin,
@ -1107,7 +1094,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event # For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we # passes auth based on the current state. If it doesn't then we
# "soft-fail" the event. # "soft-fail" the event.
await self._check_for_soft_fail(event, state, origin=origin) await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled) await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1589,7 +1576,7 @@ class FederationEventHandler:
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, self,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state_ids: Optional[StateMap[str]],
origin: str, origin: str,
) -> None: ) -> None:
"""Checks if we should soft fail the event; if so, marks the event as """Checks if we should soft fail the event; if so, marks the event as
@ -1597,7 +1584,7 @@ class FederationEventHandler:
Args: Args:
event event
state: The state at the event if we don't have all the event's prev events state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from. origin: The host the event originates from.
""" """
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@ -1613,7 +1600,7 @@ class FederationEventHandler:
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state". # Calculate the "current state".
if state is not None: if state_ids is not None:
# If we're explicitly given the state then we won't have all the # If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case # prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for # we want to be a little careful as we might have been down for
@ -1626,17 +1613,20 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets_d = await self._state_storage.get_state_groups( state_sets_d = await self._state_storage.get_state_groups_ids(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state) state_sets.append(state_ids)
current_states = await self._state_handler.resolve_events( current_state_ids = (
room_version, state_sets, event await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
) )
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else: else:
current_state_ids = await self._state_handler.get_current_state_ids( current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids event.room_id, latest_event_ids=extrem_ids

View File

@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.types import (
MutableStateMap,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -1022,8 +1029,35 @@ class EventCreationHandler:
# #
# TODO(faster_joins): figure out how this works, and make sure that the # TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete. # old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids) metadata = await self.store.get_metadata_for_events(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
state_map_for_event: MutableStateMap[str] = {}
for state_id in state_event_ids:
data = metadata.get(state_id)
if data is None:
# We're trying to persist a new historical batch of events
# with the given state, e.g. via
# `RoomBatchSendEventRestServlet`. The state can be inferred
# by Synapse or set directly by the client.
#
# Either way, we should have persisted all the state before
# getting here.
raise Exception(
f"State event {state_id} not found in DB,"
" Synapse should have persisted it before using it."
)
if data.state_key is None:
raise Exception(
f"Trying to set non-state event {state_id} as state"
)
state_map_for_event[(data.event_type, data.state_key)] = state_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
)
else: else:
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)

View File

@ -261,7 +261,7 @@ class StateHandler:
async def compute_event_context( async def compute_event_context(
self, self,
event: EventBase, event: EventBase,
old_state: Optional[Iterable[EventBase]] = None, state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False, partial_state: bool = False,
) -> EventContext: ) -> EventContext:
"""Build an EventContext structure for a non-outlier event. """Build an EventContext structure for a non-outlier event.
@ -273,12 +273,12 @@ class StateHandler:
Args: Args:
event: event:
old_state: The state at the event if it can't be state_ids_before_event: The event ids of the state before the event if
calculated from existing events. This is normally only specified it can't be calculated from existing events. This is normally
when receiving an event from federation where we don't have the only specified when receiving an event from federation where we
prev events for, e.g. when backfilling. don't have the prev events, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical partial_state: True if `state_ids_before_event` is partial and omits
membership events non-critical membership events
Returns: Returns:
The event context. The event context.
""" """
@ -286,13 +286,11 @@ class StateHandler:
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()
# #
# first of all, figure out the state before the event # first of all, figure out the state before the event, unless we
# already have it.
# #
if old_state: if state_ids_before_event:
# if we're given the state before the event, then we use that # if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
state_group_before_event = None state_group_before_event = None
state_group_before_event_prev_group = None state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None deltas_to_state_group_before_event = None

View File

@ -16,6 +16,8 @@ import collections.abc
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
make_in_list_sql_clause,
) )
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100 MAX_STATE_DELTA_HOPS = 100
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""
room_id: str
event_type: str
state_key: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id) v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v: if not v:
@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version return room_version
async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
"""Get some metadata (room_id, type, state_key) for the given events.
This method is a faster alternative than fetching the full events from
the DB, and should be used when the full event is not needed.
Returns metadata for rejected and redacted events. Events that have not
been persisted are omitted from the returned dict.
"""
def get_metadata_for_events_txn(
txn: LoggingTransaction,
batch_ids: Collection[str],
) -> Dict[str, EventMetadata]:
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", batch_ids
)
sql = f"""
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
LEFT JOIN state_events USING (event_id)
WHERE {clause}
"""
txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
)
for event_id, room_id, event_type, state_key in txn
}
result_map: Dict[str, EventMetadata] = {}
for batch_ids in batch_iter(event_ids, 1000):
result_map.update(
await self.db_pool.runInteraction(
"get_metadata_for_events",
get_metadata_for_events_txn,
batch_ids=batch_ids,
)
)
return result_map
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists. """Get the predecessor of an upgraded room if it exists.
Otherwise return None. Otherwise return None.

View File

@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event. # federation handler wanting to backfill the fake event.
self.get_success( self.get_success(
federation_event_handler._process_received_pdu( federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state self.OTHER_SERVER_NAME,
event,
state_ids={
(e.type, e.state_key): e.event_id for e in current_state
},
) )
) )

View File

@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None): def persist_event(self, event, state=None):
"""Persist the event, with optional state""" """Persist the event, with optional state"""
context = self.get_success( context = self.get_success(
self.state.compute_event_context(event, old_state=state) self.state.compute_event_context(event, state_ids_before_event=state)
) )
self.get_success(self.persistence.persist_event(event, context)) self.get_success(self.persistence.persist_event(event, context))
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then # setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state. # include it, and so the resolved state won't match the new state.
state_before_gap = dict( state_before_gap = dict(
self.get_success(self.state.get_current_state(self.room_id)) self.get_success(self.state.get_current_state_ids(self.room_id))
) )
state_before_gap.pop(("m.room.history_visibility", "")) state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success( context = self.get_success(
self.state.compute_event_context( self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values() remote_event_2,
state_ids_before_event=state_before_gap,
) )
) )
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id]) self.assert_extremities([remote_event_2.event_id])
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id]) self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6, RoomVersions.V6,
) )
state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values()) self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event. # Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id]) self.assert_extremities([local_message_event_id, remote_event_2.event_id])

View File

@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
] ]
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state) self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
) )
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
] ]
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state) self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
) )
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())