Batch fetch _get_state_groups_from_groups

This commit is contained in:
Erik Johnston 2016-02-10 13:24:42 +00:00
parent 24f00a6c33
commit 5189bfdef4

View File

@ -171,15 +171,10 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
def _get_state_groups_from_groups(self, groups_and_types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> state event ids """Returns dictionary state_group -> state event ids
Args:
groups_and_types (list): list of 2-tuple (`group`, `types`)
""" """
def f(txn): def f(txn, groups):
results = {}
for group, types in groups_and_types:
if types is not None: if types is not None:
where_clause = "AND (%s)" % ( where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)), " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
@ -188,23 +183,30 @@ class StateStore(SQLBaseStore):
where_clause = "" where_clause = ""
sql = ( sql = (
"SELECT event_id FROM state_groups_state WHERE" "SELECT state_group, event_id FROM state_groups_state WHERE"
" state_group = ? %s" " state_group IN (%s) %s" % (
) % (where_clause,) ",".join("?" for _ in groups),
where_clause,
)
)
args = [group] args = list(groups)
if types is not None: if types is not None:
args.extend([i for typ in types for i in typ]) args.extend([i for typ in types for i in typ])
txn.execute(sql, args) txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
results[group] = [r[0] for r in txn.fetchall()] results = {}
for row in rows:
results.setdefault(row["state_group"], []).append(row["event_id"])
return results return results
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks:
return self.runInteraction( return self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
f, f, chunk
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -349,7 +351,7 @@ class StateStore(SQLBaseStore):
all events are returned. all events are returned.
""" """
results = {} results = {}
missing_groups_and_types = [] missing_groups = []
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache( state_dict, missing_types, got_all = self._get_some_state_from_cache(
@ -358,7 +360,7 @@ class StateStore(SQLBaseStore):
results[group] = state_dict results[group] = state_dict
if not got_all: if not got_all:
missing_groups_and_types.append((group, missing_types)) missing_groups.append(group)
else: else:
for group in set(groups): for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache( state_dict, got_all = self._get_all_state_from_cache(
@ -367,9 +369,9 @@ class StateStore(SQLBaseStore):
results[group] = state_dict results[group] = state_dict
if not got_all: if not got_all:
missing_groups_and_types.append((group, None)) missing_groups.append(group)
if not missing_groups_and_types: if not missing_groups:
defer.returnValue({ defer.returnValue({
group: { group: {
type_tuple: event type_tuple: event
@ -383,7 +385,7 @@ class StateStore(SQLBaseStore):
cache_seq_num = self._state_group_cache.sequence cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups( group_state_dict = yield self._get_state_groups_from_groups(
missing_groups_and_types missing_groups, types
) )
state_events = yield self._get_events( state_events = yield self._get_events(