Store state groups separately from events (#2784)

* Split state group persist into seperate storage func

* Add per database engine code for state group id gen

* Move store_state_group to StateReadStore

This allows other workers to use it, and so resolve state.

* Hook up store_state_group

* Fix tests

* Rename _store_mult_state_groups_txn

* Rename StateGroupReadStore

* Remove redundant _have_persisted_state_group_txn

* Update comments

* Comment compute_event_context

* Set start val for state_group_id_seq

... otherwise we try to recreate old state groups

* Update comments

* Don't store state for outliers

* Update comment

* Update docstring as state groups are ints
This commit is contained in:
Erik Johnston 2018-02-06 14:31:24 +00:00 committed by GitHub
parent b31bf0bb51
commit 3d33eef6fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 341 additions and 204 deletions

View File

@ -25,7 +25,9 @@ class EventContext(object):
The current state map excluding the current event. The current state map excluding the current event.
(type, state_key) -> event_id (type, state_key) -> event_id
state_group (int): state group id state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else rejected (bool|str): A rejection reason if the event was rejected, else
False False

View File

@ -1831,8 +1831,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
@ -1913,8 +1913,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
try: try:
@ -1923,11 +1923,15 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warn("Failed auth resolution for %r because %s", event, e)
raise e raise e
def _update_context_for_auth_events(self, context, auth_events, @defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
event_key): event_key):
"""Update the state_ids in an event context after auth event resolution """Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args: Args:
event (Event): The event we're handling the context for
context (synapse.events.snapshot.EventContext): event context context (synapse.events.snapshot.EventContext): event context
to be updated to be updated
@ -1950,7 +1954,13 @@ class FederationHandler(BaseHandler):
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems() k: a.event_id for k, a in auth_events.iteritems()
}) })
context.state_group = self.store.get_next_state_group() context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth): def construct_auth_difference(self, local_auth, remote_auth):

View File

@ -19,7 +19,7 @@ from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupReadStore from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore): class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs) super(SlavedEventStore, self).__init__(db_conn, hs)

View File

@ -183,8 +183,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args: Args:
event (synapse.events.EventBase): event (synapse.events.EventBase):
old_state (dict|None): The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns: Returns:
synapse.events.snapshot.EventContext: synapse.events.snapshot.EventContext:
""" """
@ -208,15 +215,22 @@ class StateHandler(object):
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {} context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = self.store.get_next_state_group()
# We don't store state for outliers, so we don't generate a state
# froup for it.
context.state_group = None
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
# We already have the state, so we don't need to calculate it.
# Let's just correctly fill out the context and create a
# new state group for it.
context = EventContext() context = EventContext()
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
@ -229,6 +243,14 @@ class StateHandler(object):
else: else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=context.current_state_ids,
)
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@ -242,7 +264,8 @@ class StateHandler(object):
context = EventContext() context = EventContext()
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() # If this is a state event then we need to create a new state
# group for the state after this event.
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in context.prev_state_ids:
@ -253,24 +276,43 @@ class StateHandler(object):
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
if entry.state_group: if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
context.prev_group = entry.state_group context.prev_group = entry.state_group
context.delta_ids = { context.delta_ids = {
key: event.event_id key: event.event_id
} }
elif entry.prev_group: elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids) context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id context.delta_ids[key] = event.event_id
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = entry.delta_ids
if entry.state_group is None:
entry.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=entry.prev_group,
delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids,
)
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)

View File

@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")

View File

@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
txn.execute("SELECT nextval('state_group_id_seq')")
return txn.fetchone()[0]

View File

@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
import struct import struct
import threading
class Sqlite3Engine(object): class Sqlite3Engine(object):
@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module self.module = database_module
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
return return
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
# We do application locking here since if we're using sqlite then
# we are a single process synapse.
with self._current_state_group_id_lock:
if self._current_state_group_id is None:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
self._current_state_group_id = txn.fetchone()[0]
self._current_state_group_id += 1
return self._current_state_group_id
# Following functions taken from: https://github.com/coleifer/peewee # Following functions taken from: https://github.com/coleifer/peewee

View File

@ -755,9 +755,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
) )
# Insert into the state_groups, state_groups_state, and # Insert into event_to_state_groups.
# event_to_state_groups tables. self._store_event_state_mappings_txn(txn, events_and_contexts)
self._store_mult_state_groups_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were # _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list. # rejected, and returns the filtered list.
@ -992,10 +991,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that # an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state. # so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and # insert into event_to_state_groups.
# event_to_state_groups tables.
try: try:
self._store_mult_state_groups_txn(txn, ((event, context),)) self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception: except Exception:
logger.exception("") logger.exception("")
raise raise

View File

@ -0,0 +1,37 @@
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
# if we already have some state groups, we want to start making new
# ones with a higher id.
cur.execute("SELECT max(id) FROM state_groups")
row = cur.fetchone()
if row[0] is None:
start_val = 1
else:
start_val = row[0] + 1
cur.execute(
"CREATE SEQUENCE state_group_id_seq START WITH %s",
(start_val, ),
)
def run_upgrade(*args, **kwargs):
pass

View File

@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateGroupReadStore(SQLBaseStore): class StateGroupWorkerStore(SQLBaseStore):
"""The read-only parts of StateGroupStore """The parts of StateGroupStore that can be called from workers.
None of these functions write to the state tables, so are suitable for
including in the SlavedStores.
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateGroupReadStore, self).__init__(db_conn, hs) super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@ -549,8 +546,117 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
"""Store a new set of state, returning a newly assigned state group.
class StateStore(StateGroupReadStore, BackgroundUpdateStore): Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
Returns:
Deferred[int]: The state group ID
"""
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
state_group = self.database_engine.get_next_state_group_id(txn)
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
full=True,
)
return state_group
return self.runInteraction("store_state_group", _store_state_group_txn)
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event. """ Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned This is done by the concept of `state groups`. Every event is a assigned
@ -591,27 +697,12 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
where_clause="type='m.room.member'", where_clause="type='m.room.member'",
) )
def _have_persisted_state_group_txn(self, txn, state_group): def _store_event_state_mappings_txn(self, txn, events_and_contexts):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
continue continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its # if the event was rejected, just give it the same state as its
# predecessor. # predecessor.
if context.rejected: if context.rejected:
@ -620,90 +711,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count return count
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size): def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding """This background update will slowly deduplicate state by reencoding

View File

@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
context = EventContext() context = EventContext()
context.current_state_ids = state_ids context.current_state_ids = state_ids
context.prev_state_ids = state_ids context.prev_state_ids = state_ids
elif not backfill: else:
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event) context = yield state_handler.compute_event_context(event)
else:
context = EventContext()
context.push_actions = push_actions context.push_actions = push_actions

View File

@ -80,14 +80,14 @@ class StateGroupStore(object):
return defer.succeed(groups) return defer.succeed(groups)
def store_state_groups(self, event, context): def store_state_group(self, event_id, room_id, prev_group, delta_ids,
if context.current_state_ids is None: current_state_ids):
return state_group = self._next_group
self._next_group += 1
state_events = dict(context.current_state_ids) self._group_to_state[state_group] = dict(current_state_ids)
self._group_to_state[context.state_group] = state_events return state_group
self._event_to_state_group[event.event_id] = context.state_group
def get_events(self, event_ids, **kwargs): def get_events(self, event_ids, **kwargs):
return { return {
@ -95,10 +95,19 @@ class StateGroupStore(object):
if e_id in self._event_id_to_event if e_id in self._event_id_to_event
} }
def get_state_group_delta(self, name):
return (None, None)
def register_events(self, events): def register_events(self, events):
for e in events: for e in events:
self._event_id_to_event[e.event_id] = e self._event_id_to_event[e.event_id] = e
def register_event_context(self, event, context):
self._event_to_state_group[event.event_id] = context.state_group
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
class DictObj(dict): class DictObj(dict):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -137,15 +146,7 @@ class Graph(object):
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = StateGroupStore()
spec_set=[
"get_state_groups_ids",
"add_event_hashes",
"get_events",
"get_next_state_group",
"get_state_group_delta",
]
)
hs = Mock(spec_set=[ hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock", "get_datastore", "get_auth", "get_state_handler", "get_clock",
"get_state_resolution_handler", "get_state_resolution_handler",
@ -156,9 +157,6 @@ class StateTestCase(unittest.TestCase):
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@ -197,14 +195,13 @@ class StateTestCase(unittest.TestCase):
} }
) )
store = StateGroupStore() self.store.register_events(graph.walk())
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {} context_store = {}
for event in graph.walk(): for event in graph.walk():
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids)) self.assertEqual(2, len(context_store["D"].prev_state_ids))
@ -249,16 +246,13 @@ class StateTestCase(unittest.TestCase):
} }
) )
store = StateGroupStore() self.store.register_events(graph.walk())
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
for event in graph.walk(): for event in graph.walk():
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertSetEqual( self.assertSetEqual(
@ -315,16 +309,13 @@ class StateTestCase(unittest.TestCase):
} }
) )
store = StateGroupStore() self.store.register_events(graph.walk())
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
for event in graph.walk(): for event in graph.walk():
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertSetEqual( self.assertSetEqual(
@ -398,16 +389,13 @@ class StateTestCase(unittest.TestCase):
self._add_depths(nodes, edges) self._add_depths(nodes, edges)
graph = Graph(nodes, edges) graph = Graph(nodes, edges)
store = StateGroupStore() self.store.register_events(graph.walk())
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
for event in graph.walk(): for event in graph.walk():
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertSetEqual( self.assertSetEqual(
@ -467,7 +455,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event") prev_event_id = "prev_event_id"
event = create_event(
type="test_message", name="event2",
prev_events=[(prev_event_id, {})],
)
old_state = [ old_state = [
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
@ -475,11 +467,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = self.store.store_state_group(
prev_event_id, event.room_id, None, None,
self.store.get_state_groups_ids.return_value = { {(e.type, e.state_key): e.event_id for e in old_state},
group_name: {(e.type, e.state_key): e.event_id for e in old_state}, )
} self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
@ -492,7 +484,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
event = create_event(type="state", state_key="", name="event") prev_event_id = "prev_event_id"
event = create_event(
type="state", state_key="", name="event2",
prev_events=[(prev_event_id, {})],
)
old_state = [ old_state = [
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
@ -500,11 +496,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = self.store.store_state_group(
prev_event_id, event.room_id, None, None,
self.store.get_state_groups_ids.return_value = { {(e.type, e.state_key): e.event_id for e in old_state},
group_name: {(e.type, e.state_key): e.event_id for e in old_state}, )
} self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
@ -517,7 +513,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = create_event(type="test_message", name="event") prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
type="test_message", name="event3",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
creation = create_event( creation = create_event(
type=EventTypes.Create, state_key="" type=EventTypes.Create, state_key=""
@ -537,12 +538,12 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore() self.store.register_events(old_state_1)
store.register_events(old_state_1) self.store.register_events(old_state_2)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
@ -550,7 +551,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = create_event(type="test4", state_key="", name="event") prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
type="test4", state_key="", name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
creation = create_event( creation = create_event(
type=EventTypes.Create, state_key="" type=EventTypes.Create, state_key=""
@ -575,7 +581,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2) store.register_events(old_state_2)
self.store.get_events = store.get_events self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
@ -583,7 +591,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_standard_depth_conflict(self): def test_standard_depth_conflict(self):
event = create_event(type="test4", name="event") prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
type="test4", name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
member_event = create_event( member_event = create_event(
type=EventTypes.Member, type=EventTypes.Member,
@ -615,7 +628,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2) store.register_events(old_state_2)
self.store.get_events = store.get_events self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual( self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")] old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@ -639,19 +654,26 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_1) store.register_events(old_state_1)
store.register_events(old_state_2) store.register_events(old_state_2)
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual( self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")] old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
) )
def _get_context(self, event, old_state_1, old_state_2): def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
group_name_1 = "group_name_1" old_state_2):
group_name_2 = "group_name_2" sg1 = self.store.store_state_group(
prev_event_id_1, event.room_id, None, None,
{(e.type, e.state_key): e.event_id for e in old_state_1},
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
self.store.get_state_groups_ids.return_value = { sg2 = self.store.store_state_group(
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, prev_event_id_2, event.room_id, None, None,
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, {(e.type, e.state_key): e.event_id for e in old_state_2},
} )
self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event) return self.state.compute_event_context(event)