diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7eb342674..a82ba1d1d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -307,6 +307,9 @@ class StateStore(SQLBaseStore): def _get_state_groups_from_groups_txn(self, txn, groups, types=None): results = {group: {} for group in groups} + if types is not None: + types = list(set(types)) # deduplicate types list + if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in @@ -379,6 +382,44 @@ class StateStore(SQLBaseStore): next_group = group while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indicies (which we can't add until + # after we finish deduping state, which requires this func) + if types is not None: + args = [next_group] + [i for typ in types for i in typ] + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? %s" % (where_clause,), + args + ) + rows = txn.fetchall() + + results[group].update({ + (typ, state_key): event_id + for typ, state_key, event_id in rows + if (typ, state_key) not in results[group] + }) + + # If the lengths match then we must have all the types, + # so no need to go walk further down the tree. + if len(results[group]) == len(types): + break + else: + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ?", + (next_group,) + ) + rows = txn.fetchall() + + results[group].update({ + (typ, state_key): event_id + for typ, state_key, event_id in rows + if (typ, state_key) not in results[group] + }) + next_group = self._simple_select_one_onecol_txn( txn, table="state_group_edges", @@ -389,26 +430,6 @@ class StateStore(SQLBaseStore): if next_group: group_tree.append(next_group) - sql = (""" - SELECT type, state_key, event_id FROM state_groups_state - INNER JOIN ( - SELECT type, state_key, max(state_group) as state_group - FROM state_groups_state - WHERE state_group IN (%s) %s - GROUP BY type, state_key - ) USING (type, state_key, state_group); - """) % (",".join("?" for _ in group_tree), where_clause,) - - args = list(group_tree) - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] - return results @defer.inlineCallbacks