diff --git a/synapse/state.py b/synapse/state.py index daec983dc..147416fd8 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -160,14 +160,14 @@ class StateHandler(object): else: context.current_state_ids = {} context.prev_state_events = [] - context.state_group = None + context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: context.current_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = None + context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) @@ -193,7 +193,10 @@ class StateHandler(object): group, curr_state = ret context.current_state_ids = curr_state - context.state_group = group if not event.is_state() else None + if event.is_state() or group is None: + context.state_group = self.store.get_next_state_group() + else: + context.state_group = group if event.is_state(): key = (event.type, event.state_key) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bc1bc97e1..1a7d4c519 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -276,13 +276,6 @@ class EventsStore(SQLBaseStore): events_and_contexts, stream_orderings ): event.internal_metadata.stream_ordering = stream - # Assign a state group_id in case a new id is needed for - # this context. In theory we only need to assign this - # for contexts that have current_state and aren't outliers - # but that make the code more complicated. Assigning an ID - # per event only causes the state_group_ids to grow as fast - # as the stream_ordering so in practise shouldn't be a problem. - context.new_state_group_id = self._state_groups_id_gen.get_next() chunks = [ events_and_contexts[x:x + 100] @@ -309,7 +302,6 @@ class EventsStore(SQLBaseStore): try: with self._stream_id_gen.get_next() as stream_ordering: event.internal_metadata.stream_ordering = stream_ordering - context.new_state_group_id = self._state_groups_id_gen.get_next() yield self.runInteraction( "persist_event", self._persist_event_txn, @@ -523,7 +515,7 @@ class EventsStore(SQLBaseStore): # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group or context.new_state_group_id + state_group_id = context.state_group self._simple_insert_txn( txn, table="ex_outlier_stream", diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 044235328..56bfdc0b5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,6 +83,14 @@ class StateStore(SQLBaseStore): for group, event_id_map in group_to_ids.items() }) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups_state WHERE state_group = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: @@ -92,8 +100,10 @@ class StateStore(SQLBaseStore): if context.current_state_ids is None: continue - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + state_groups[event.event_id] = context.state_group + + if self._have_persisted_state_group_txn(txn, context.state_group): + logger.info("Already persisted state_group: %r", context.state_group) continue state_event_ids = dict(context.current_state_ids) @@ -101,13 +111,11 @@ class StateStore(SQLBaseStore): if event.is_state(): state_event_ids[(event.type, event.state_key)] = event.event_id - state_group = context.new_state_group_id - self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, @@ -118,7 +126,7 @@ class StateStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": state_group, + "state_group": context.state_group, "room_id": event.room_id, "type": key[0], "state_key": key[1], @@ -127,7 +135,6 @@ class StateStore(SQLBaseStore): for key, state_id in state_event_ids.items() ], ) - state_groups[event.event_id] = state_group self._simple_insert_many_txn( txn, @@ -526,3 +533,6 @@ class StateStore(SQLBaseStore): return self.runInteraction( "get_all_new_state_groups", get_all_new_state_groups_txn ) + + def get_next_state_group(self): + return self._state_groups_id_gen.get_next()