Split _get_state_for_group_from_cache into two

This commit is contained in:
Erik Johnston 2015-08-12 17:06:21 +01:00
parent 7b0e797080
commit df361d08f7

View File

@ -287,20 +287,25 @@ class StateStore(SQLBaseStore):
f, f,
) )
def _get_state_for_group_from_cache(self, group, types=None): def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that `missing_types` is the list of types that aren't in the cache for that
group, or None if `types` is None. `got_all` is a bool indicating if group. `got_all` is a bool indicating if we successfully retrieved all
we successfully retrieved all requests state from the cache, if False requests state from the cache, if False we need to query the DB for the
we need to query the DB for the missing state. missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
""" """
is_all, state_dict = self._state_group_cache.get(group) is_all, state_dict = self._state_group_cache.get(group)
type_to_key = {} type_to_key = {}
missing_types = set() missing_types = set()
if types is not None:
for typ, state_key in types: for typ, state_key in types:
if state_key is None: if state_key is None:
type_to_key[typ] = None type_to_key[typ] = None
@ -312,17 +317,9 @@ class StateStore(SQLBaseStore):
if (typ, state_key) not in state_dict: if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key)) missing_types.add((typ, state_key))
if is_all:
missing_types = set()
if types is None:
return state_dict, set(), True
sentinel = object() sentinel = object()
def include(typ, state_key): def include(typ, state_key):
if types is None:
return True
valid_state_keys = type_to_key.get(typ, sentinel) valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel: if valid_state_keys is sentinel:
return False return False
@ -340,6 +337,19 @@ class StateStore(SQLBaseStore):
if include(k[0], k[1]) if include(k[0], k[1])
}, missing_types, got_all }, missing_types, got_all
def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
indicating if we successfully retrieved all requests state from the
cache, if False we need to query the DB for the missing state.
Args:
group: The state group to lookup
"""
is_all, state_dict = self._state_group_cache.get(group)
return state_dict, is_all
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events """Given list of groups returns dict of group -> list of state events
@ -349,8 +359,9 @@ class StateStore(SQLBaseStore):
""" """
results = {} results = {}
missing_groups_and_types = [] missing_groups_and_types = []
if types is not None:
for group in set(groups): for group in set(groups):
state_dict, missing_types, got_all = self._get_state_for_group_from_cache( state_dict, missing_types, got_all = self._get_some_state_from_cache(
group, types group, types
) )
@ -359,8 +370,18 @@ class StateStore(SQLBaseStore):
if not got_all: if not got_all:
missing_groups_and_types.append(( missing_groups_and_types.append((
group, group,
missing_types if types else None missing_types
)) ))
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
group
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, None))
if not missing_groups_and_types: if not missing_groups_and_types:
defer.returnValue({ defer.returnValue({