Generate state group ids in state layer

This commit is contained in:
Erik Johnston 2016-08-31 10:09:46 +01:00
parent 5dc2a702cf
commit 1bb8ec296d
3 changed files with 24 additions and 19 deletions

View File

@ -160,14 +160,14 @@ class StateHandler(object):
else: else:
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = None context.state_group = self.store.get_next_state_group()
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context.current_state_ids = { context.current_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = None context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
@ -193,7 +193,10 @@ class StateHandler(object):
group, curr_state = ret group, curr_state = ret
context.current_state_ids = curr_state context.current_state_ids = curr_state
context.state_group = group if not event.is_state() else None if event.is_state() or group is None:
context.state_group = self.store.get_next_state_group()
else:
context.state_group = group
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)

View File

@ -276,13 +276,6 @@ class EventsStore(SQLBaseStore):
events_and_contexts, stream_orderings events_and_contexts, stream_orderings
): ):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = self._state_groups_id_gen.get_next()
chunks = [ chunks = [
events_and_contexts[x:x + 100] events_and_contexts[x:x + 100]
@ -309,7 +302,6 @@ class EventsStore(SQLBaseStore):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = self._state_groups_id_gen.get_next()
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
@ -523,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the # Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers. # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id state_group_id = context.state_group
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="ex_outlier_stream", table="ex_outlier_stream",

View File

@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.items() for group, event_id_map in group_to_ids.items()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups_state WHERE state_group = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts): def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
@ -92,8 +100,10 @@ class StateStore(SQLBaseStore):
if context.current_state_ids is None: if context.current_state_ids is None:
continue continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue continue
state_event_ids = dict(context.current_state_ids) state_event_ids = dict(context.current_state_ids)
@ -101,13 +111,11 @@ class StateStore(SQLBaseStore):
if event.is_state(): if event.is_state():
state_event_ids[(event.type, event.state_key)] = event.event_id state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": state_group, "id": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
@ -118,7 +126,7 @@ class StateStore(SQLBaseStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": state_group, "state_group": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
@ -127,7 +135,6 @@ class StateStore(SQLBaseStore):
for key, state_id in state_event_ids.items() for key, state_id in state_event_ids.items()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -526,3 +533,6 @@ class StateStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_all_new_state_groups", get_all_new_state_groups_txn "get_all_new_state_groups", get_all_new_state_groups_txn
) )
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()