diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 64c5ae992..19b16ed40 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -283,6 +283,9 @@ class StateStore(SQLBaseStore): def _get_state_for_group_from_cache(self, group, types=None): """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the + list of types that aren't in the cache for that group. """ is_all, state_dict = self._state_group_cache.get(group) @@ -300,29 +303,31 @@ class StateStore(SQLBaseStore): if (typ, state_key) not in state_dict: missing_types.add((typ, state_key)) - if is_all and types is None: - return state_dict, missing_types + if is_all: + missing_types = set() + if types is None: + return state_dict, set(), True - if is_all or (types is not None and not missing_types): - sentinel = object() + sentinel = object() - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return False - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True + def include(typ, state_key): + if types is None: + return True + + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False - return { - k: v - for k, v in state_dict.items() - if v and include(k[0], k[1]) - }, missing_types - - return {}, missing_types + return { + k: v + for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, not missing_types and types is not None @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): @@ -333,25 +338,28 @@ class StateStore(SQLBaseStore): """ results = {} missing_groups_and_types = [] - for group in groups: - state_dict, missing_types = self._get_state_for_group_from_cache( + for group in set(groups): + state_dict, missing_types, got_all = self._get_state_for_group_from_cache( group, types ) - if types is not None and not missing_types: - results[group] = { - key: value - for key, value in state_dict.items() - if value - } - else: + results[group] = state_dict + + if not got_all: missing_groups_and_types.append(( group, missing_types if types else None )) if not missing_groups_and_types: - defer.returnValue(results) + defer.returnValue({ + k: { + key: ev + for key, ev in state.items() + if ev + } + for k, state in results.items() + }) # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence @@ -371,10 +379,15 @@ class StateStore(SQLBaseStore): } for group, state_ids in group_state_dict.items(): - state_dict = { - key: None - for key in missing_types - } + if types: + state_dict = { + key: None + for key in types + } + state_dict.update(results[group]) + else: + state_dict = results[group] + evs = [ state_events[e_id] for e_id in state_ids if e_id in state_events # This can happen if event is rejected. @@ -392,11 +405,11 @@ class StateStore(SQLBaseStore): full=(types is None), ) - results[group] = { + results[group].update({ key: value for key, value in state_dict.items() if value - } + }) defer.returnValue(results)