Faster joins: persist to database (#12012)

When we get a partial_state response from send_join, store information in the
database about it:
 * store a record about the room as a whole having partial state, and stash the
   list of member servers too.
 * flag the join event itself as having partial state
 * also, for any new events whose prev-events are partial-stated, note that
   they will *also* be partial-stated.

We don't yet make any attempt to interpret this data, so API calls (and a bunch
of other things) are just going to get incorrect data.
This commit is contained in:
Richard van der Hoff 2022-03-01 12:49:54 +00:00 committed by GitHub
parent 4ccc2d09aa
commit e2e1d90a5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 297 additions and 32 deletions

1
changelog.d/12012.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database.

View File

@ -101,6 +101,9 @@ class EventContext:
As with _current_state_ids, this is a private attribute. It should be As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids. accessed via get_prev_state_ids.
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
""" """
rejected: Union[bool, str] = False rejected: Union[bool, str] = False
@ -113,12 +116,15 @@ class EventContext:
_current_state_ids: Optional[StateMap[str]] = None _current_state_ids: Optional[StateMap[str]] = None
_prev_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None
partial_state: bool = False
@staticmethod @staticmethod
def with_state( def with_state(
state_group: Optional[int], state_group: Optional[int],
state_group_before_event: Optional[int], state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]], current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None, prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None, delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext": ) -> "EventContext":
@ -129,6 +135,7 @@ class EventContext:
state_group_before_event=state_group_before_event, state_group_before_event=state_group_before_event,
prev_group=prev_group, prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,
partial_state=partial_state,
) )
@staticmethod @staticmethod
@ -170,6 +177,7 @@ class EventContext:
"prev_group": self.prev_group, "prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids), "delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None, "app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
} }
@staticmethod @staticmethod
@ -196,6 +204,7 @@ class EventContext:
prev_group=input["prev_group"], prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]), delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"], rejected=input["rejected"],
partial_state=input.get("partial_state", False),
) )
app_service_id = input["app_service_id"] app_service_id = input["app_service_id"]

View File

@ -519,8 +519,17 @@ class FederationHandler:
state_events=state, state_events=state,
) )
if ret.partial_state:
await self.store.store_partial_state_room(room_id, ret.servers_in_room)
max_stream_id = await self._federation_event_handler.process_remote_join( max_stream_id = await self._federation_event_handler.process_remote_join(
origin, room_id, auth_chain, state, event, room_version_obj origin,
room_id,
auth_chain,
state,
event,
room_version_obj,
partial_state=ret.partial_state,
) )
# We wait here until this instance has seen the events come down # We wait here until this instance has seen the events come down

View File

@ -397,6 +397,7 @@ class FederationEventHandler:
state: List[EventBase], state: List[EventBase],
event: EventBase, event: EventBase,
room_version: RoomVersion, room_version: RoomVersion,
partial_state: bool,
) -> int: ) -> int:
"""Persists the events returned by a send_join """Persists the events returned by a send_join
@ -412,6 +413,7 @@ class FederationEventHandler:
event event
room_version: The room version we expect this room to have, and room_version: The room version we expect this room to have, and
will raise if it doesn't match the version in the create event. will raise if it doesn't match the version in the create event.
partial_state: True if the state omits non-critical membership events
Returns: Returns:
The stream ID after which all events have been persisted. The stream ID after which all events have been persisted.
@ -453,10 +455,14 @@ class FederationEventHandler:
) )
# and now persist the join event itself. # and now persist the join event itself.
logger.info("Peristing join-via-remote %s", event) logger.info(
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context( context = await self._state_handler.compute_event_context(
event, old_state=state event,
old_state=state,
partial_state=partial_state,
) )
context = await self._check_event_auth(origin, event, context) context = await self._check_event_auth(origin, event, context)
@ -698,6 +704,8 @@ class FederationEventHandler:
try: try:
state = await self._resolve_state_at_missing_prevs(origin, event) state = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
await self._process_received_pdu( await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled origin, event, state=state, backfilled=backfilled
) )
@ -1791,6 +1799,7 @@ class FederationEventHandler:
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
prev_group=prev_group, prev_group=prev_group,
delta_ids=state_updates, delta_ids=state_updates,
partial_state=context.partial_state,
) )
async def _run_push_actions_and_persist_event( async def _run_push_actions_and_persist_event(

View File

@ -992,6 +992,8 @@ class EventCreationHandler:
and full_state_ids_at_event and full_state_ids_at_event
and builder.internal_metadata.is_historical() and builder.internal_metadata.is_historical()
): ):
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(full_state_ids_at_event) old_state = await self.store.get_events_as_list(full_state_ids_at_event)
context = await self.state.compute_event_context(event, old_state=old_state) context = await self.state.compute_event_context(event, old_state=old_state)
else: else:

View File

@ -258,7 +258,10 @@ class StateHandler:
return await self.store.get_joined_hosts(room_id, entry) return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context( async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
partial_state: bool = False,
) -> EventContext: ) -> EventContext:
"""Build an EventContext structure for a non-outlier event. """Build an EventContext structure for a non-outlier event.
@ -273,6 +276,8 @@ class StateHandler:
calculated from existing events. This is normally only specified calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling. prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
Returns: Returns:
The event context. The event context.
""" """
@ -295,8 +300,28 @@ class StateHandler:
else: else:
# otherwise, we'll need to resolve the state across the prev_events. # otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
# partial_state should not be set explicitly in this case:
# we work it out dynamically
assert not partial_state
# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
# their states - but I don't think that really matters; it just means we
# might redundantly recalculate the state for this event later.)
prev_event_ids = event.prev_event_ids()
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
if any(incomplete_prev_events.values()):
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
entry = await self.resolve_state_groups_for_events( entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids() event.room_id, event.prev_event_ids()
) )
@ -342,6 +367,7 @@ class StateHandler:
prev_state_ids=state_ids_before_event, prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event_prev_group, prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event, delta_ids=deltas_to_state_group_before_event,
partial_state=partial_state,
) )
# #
@ -373,6 +399,7 @@ class StateHandler:
prev_state_ids=state_ids_before_event, prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event, prev_group=state_group_before_event,
delta_ids=delta_ids, delta_ids=delta_ids,
partial_state=partial_state,
) )
@measure_func() @measure_func()

View File

@ -2145,6 +2145,14 @@ class PersistEventsStore:
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
# double-check that we don't have any events that claim to be outliers
# *and* have partial state (which is meaningless: we should have no
# state at all for an outlier)
if context.partial_state:
raise ValueError(
"Outlier event %s claims to have partial state", event.event_id
)
continue continue
# if the event was rejected, just give it the same state as its # if the event was rejected, just give it the same state as its
@ -2155,6 +2163,23 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
# if we have partial state for these events, record the fact. (This happens
# here rather than in _store_event_txn because it also needs to happen when
# we de-outlier an event.)
self.db_pool.simple_insert_many_txn(
txn,
table="partial_state_events",
keys=("room_id", "event_id"),
values=[
(
event.room_id,
event.event_id,
)
for event, ctx in events_and_contexts
if ctx.partial_state
],
)
self.db_pool.simple_upsert_many_txn( self.db_pool.simple_upsert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",

View File

@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore):
"get_event_id_for_timestamp_txn", "get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn, get_event_id_for_timestamp_txn,
) )
@cachedList("is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
"""Checks which of the given events have partial state"""
result = await self.db_pool.simple_select_many_batch(
table="partial_state_events",
column="event_id",
iterable=event_ids,
retcols=["event_id"],
desc="get_partial_state_events",
)
# convert the result to a dict, to make @cachedList work
partial = {r["event_id"] for r in result}
return {e_id: e_id in partial for e_id in event_ids}
@cached()
async def is_partial_state_event(self, event_id: str) -> bool:
"""Checks if the given event has partial state"""
result = await self.db_pool.simple_select_one_onecol(
table="partial_state_events",
keyvalues={"event_id": event_id},
retcol="1",
allow_none=True,
desc="is_partial_state_event",
)
return result is not None

View File

@ -20,6 +20,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Collection,
Dict, Dict,
List, List,
Optional, Optional,
@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
lock=False, lock=False,
) )
async def store_partial_state_room(
self,
room_id: str,
servers: Collection[str],
) -> None:
"""Mark the given room as containing events with partial state
Args:
room_id: the ID of the room
servers: other servers known to be in the room
"""
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
room_id,
servers,
)
@staticmethod
def _store_partial_state_room_txn(
txn: LoggingTransaction, room_id: str, servers: Collection[str]
) -> None:
DatabasePool.simple_insert_txn(
txn,
table="partial_state_rooms",
values={
"room_id": room_id,
},
)
DatabasePool.simple_insert_many_txn(
txn,
table="partial_state_rooms_servers",
keys=("room_id", "server_name"),
values=((room_id, s) for s in servers),
)
async def maybe_store_room_on_outlier_membership( async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion self, room_id: str, room_version: RoomVersion
) -> None: ) -> None:

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- rooms which we have done a partial-state-style join to
CREATE TABLE IF NOT EXISTS partial_state_rooms (
room_id TEXT PRIMARY KEY,
FOREIGN KEY(room_id) REFERENCES rooms(room_id)
);
-- a list of remote servers we believe are in the room
CREATE TABLE IF NOT EXISTS partial_state_rooms_servers (
room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
server_name TEXT NOT NULL,
UNIQUE(room_id, server_name)
);
-- a list of events with partial state. We can't store this in the `events` table
-- itself, because `events` is meant to be append-only.
CREATE TABLE IF NOT EXISTS partial_state_events (
-- the room_id is denormalised for efficient indexing (the canonical source is `events`)
room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
event_id TEXT NOT NULL REFERENCES events(event_id),
UNIQUE(event_id)
);
CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx
ON partial_state_events (room_id);

View File

@ -0,0 +1,72 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This migration adds triggers to the partial_state_events tables to enforce uniqueness
Triggers cannot be expressed in .sql files, so we have to use a separate file.
"""
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Cursor
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
# complain if the room_id in partial_state_events doesn't match
# that in `events`. We already have a fk constraint which ensures that the event
# exists in `events`, so all we have to do is raise if there is a row with a
# matching stream_ordering but not a matching room_id.
if isinstance(database_engine, Sqlite3Engine):
cur.execute(
"""
CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
BEFORE INSERT ON partial_state_events
FOR EACH ROW
BEGIN
SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
WHERE EXISTS (
SELECT 1 FROM events
WHERE events.event_id = NEW.event_id
AND events.room_id != NEW.room_id
);
END;
"""
)
elif isinstance(database_engine, PostgresEngine):
cur.execute(
"""
CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$
BEGIN
IF EXISTS (
SELECT 1 FROM events
WHERE events.event_id = NEW.event_id
AND events.room_id != NEW.room_id
) THEN
RAISE EXCEPTION 'Incorrect room_id in partial_state_events';
END IF;
RETURN NEW;
END;
$BODY$ LANGUAGE plpgsql;
"""
)
cur.execute(
"""
CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events
FOR EACH ROW
EXECUTE PROCEDURE check_partial_state_events()
"""
)
else:
raise NotImplementedError("Unknown database engine")

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import Collection, Dict, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -70,7 +70,7 @@ def create_event(
return event return event
class StateGroupStore: class _DummyStore:
def __init__(self): def __init__(self):
self._event_to_state_group = {} self._event_to_state_group = {}
self._group_to_state = {} self._group_to_state = {}
@ -105,6 +105,11 @@ class StateGroupStore:
if e_id in self._event_id_to_event if e_id in self._event_id_to_event
} }
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
return {e: False for e in event_ids}
async def get_state_group_delta(self, name): async def get_state_group_delta(self, name):
return None, None return None, None
@ -157,8 +162,8 @@ class Graph:
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = StateGroupStore() self.dummy_store = _DummyStore()
storage = Mock(main=self.store, state=self.store) storage = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock( hs = Mock(
spec_set=[ spec_set=[
"config", "config",
@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase):
] ]
) )
hs.config = default_config("tesths", True) hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(main=self.store) hs.get_datastores.return_value = Mock(main=self.dummy_store)
hs.get_state_handler.return_value = None hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase):
edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
) )
self.store.register_events(graph.walk()) self.dummy_store.register_events(graph.walk())
context_store: dict[str, EventContext] = {} context_store: dict[str, EventContext] = {}
@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event) self.state.compute_event_context(event)
) )
self.store.register_event_context(event, context) self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
ctx_c = context_store["C"] ctx_c = context_store["C"]
@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase):
edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
) )
self.store.register_events(graph.walk()) self.dummy_store.register_events(graph.walk())
context_store = {} context_store = {}
@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event) self.state.compute_event_context(event)
) )
self.store.register_event_context(event, context) self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
# C ends up winning the resolution between B and C # C ends up winning the resolution between B and C
@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase):
edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
) )
self.store.register_events(graph.walk()) self.dummy_store.register_events(graph.walk())
context_store = {} context_store = {}
@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event) self.state.compute_event_context(event)
) )
self.store.register_event_context(event, context) self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
# C ends up winning the resolution between C and D because bans win over other # C ends up winning the resolution between C and D because bans win over other
@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase):
self._add_depths(nodes, edges) self._add_depths(nodes, edges)
graph = Graph(nodes, edges) graph = Graph(nodes, edges)
self.store.register_events(graph.walk()) self.dummy_store.register_events(graph.walk())
context_store = {} context_store = {}
@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred( context = yield defer.ensureDeferred(
self.state.compute_event_context(event) self.state.compute_event_context(event)
) )
self.store.register_event_context(event, context) self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
# B ends up winning the resolution between B and C because power levels # B ends up winning the resolution between B and C because power levels
@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase):
] ]
group_name = yield defer.ensureDeferred( group_name = yield defer.ensureDeferred(
self.store.store_state_group( self.dummy_store.store_state_group(
prev_event_id, prev_event_id,
event.room_id, event.room_id,
None, None,
@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase):
{(e.type, e.state_key): e.event_id for e in old_state}, {(e.type, e.state_key): e.event_id for e in old_state},
) )
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
context = yield defer.ensureDeferred(self.state.compute_event_context(event)) context = yield defer.ensureDeferred(self.state.compute_event_context(event))
@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase):
] ]
group_name = yield defer.ensureDeferred( group_name = yield defer.ensureDeferred(
self.store.store_state_group( self.dummy_store.store_state_group(
prev_event_id, prev_event_id,
event.room_id, event.room_id,
None, None,
@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase):
{(e.type, e.state_key): e.event_id for e in old_state}, {(e.type, e.state_key): e.event_id for e in old_state},
) )
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
context = yield defer.ensureDeferred(self.state.compute_event_context(event)) context = yield defer.ensureDeferred(self.state.compute_event_context(event))
@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
self.store.register_events(old_state_1) self.dummy_store.register_events(old_state_1)
self.store.register_events(old_state_2) self.dummy_store.register_events(old_state_2)
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore() store = _DummyStore()
store.register_events(old_state_1) store.register_events(old_state_1)
store.register_events(old_state_2) store.register_events(old_state_2)
self.store.get_events = store.get_events self.dummy_store.get_events = store.get_events
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
store = StateGroupStore() store = _DummyStore()
store.register_events(old_state_1) store.register_events(old_state_1)
store.register_events(old_state_2) store.register_events(old_state_2)
self.store.get_events = store.get_events self.dummy_store.get_events = store.get_events
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase):
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
): ):
sg1 = yield defer.ensureDeferred( sg1 = yield defer.ensureDeferred(
self.store.store_state_group( self.dummy_store.store_state_group(
prev_event_id_1, prev_event_id_1,
event.room_id, event.room_id,
None, None,
@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase):
{(e.type, e.state_key): e.event_id for e in old_state_1}, {(e.type, e.state_key): e.event_id for e in old_state_1},
) )
) )
self.store.register_event_id_state_group(prev_event_id_1, sg1) self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = yield defer.ensureDeferred( sg2 = yield defer.ensureDeferred(
self.store.store_state_group( self.dummy_store.store_state_group(
prev_event_id_2, prev_event_id_2,
event.room_id, event.room_id,
None, None,
@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase):
{(e.type, e.state_key): e.event_id for e in old_state_2}, {(e.type, e.state_key): e.event_id for e in old_state_2},
) )
) )
self.store.register_event_id_state_group(prev_event_id_2, sg2) self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
result = yield defer.ensureDeferred(self.state.compute_event_context(event)) result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result return result