Fix state cache

This commit is contained in:
Erik Johnston 2015-08-11 09:12:41 +01:00
parent 017b798e4f
commit 10b874067b

View File

@ -283,6 +283,9 @@ class StateStore(SQLBaseStore):
def _get_state_for_group_from_cache(self, group, types=None): def _get_state_for_group_from_cache(self, group, types=None):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the
list of types that aren't in the cache for that group.
""" """
is_all, state_dict = self._state_group_cache.get(group) is_all, state_dict = self._state_group_cache.get(group)
@ -300,29 +303,31 @@ class StateStore(SQLBaseStore):
if (typ, state_key) not in state_dict: if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key)) missing_types.add((typ, state_key))
if is_all and types is None: if is_all:
return state_dict, missing_types missing_types = set()
if types is None:
return state_dict, set(), True
if is_all or (types is not None and not missing_types): sentinel = object()
sentinel = object()
def include(typ, state_key): def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel) if types is None:
if valid_state_keys is sentinel: return True
return False
if valid_state_keys is None: valid_state_keys = type_to_key.get(typ, sentinel)
return True if valid_state_keys is sentinel:
if state_key in valid_state_keys:
return True
return False return False
if valid_state_keys is None:
return True
if state_key in valid_state_keys:
return True
return False
return { return {
k: v k: v
for k, v in state_dict.items() for k, v in state_dict.items()
if v and include(k[0], k[1]) if include(k[0], k[1])
}, missing_types }, missing_types, not missing_types and types is not None
return {}, missing_types
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None):
@ -333,25 +338,28 @@ class StateStore(SQLBaseStore):
""" """
results = {} results = {}
missing_groups_and_types = [] missing_groups_and_types = []
for group in groups: for group in set(groups):
state_dict, missing_types = self._get_state_for_group_from_cache( state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
group, types group, types
) )
if types is not None and not missing_types: results[group] = state_dict
results[group] = {
key: value if not got_all:
for key, value in state_dict.items()
if value
}
else:
missing_groups_and_types.append(( missing_groups_and_types.append((
group, group,
missing_types if types else None missing_types if types else None
)) ))
if not missing_groups_and_types: if not missing_groups_and_types:
defer.returnValue(results) defer.returnValue({
k: {
key: ev
for key, ev in state.items()
if ev
}
for k, state in results.items()
})
# Okay, so we have some missing_types, lets fetch them. # Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence cache_seq_num = self._state_group_cache.sequence
@ -371,10 +379,15 @@ class StateStore(SQLBaseStore):
} }
for group, state_ids in group_state_dict.items(): for group, state_ids in group_state_dict.items():
state_dict = { if types:
key: None state_dict = {
for key in missing_types key: None
} for key in types
}
state_dict.update(results[group])
else:
state_dict = results[group]
evs = [ evs = [
state_events[e_id] for e_id in state_ids state_events[e_id] for e_id in state_ids
if e_id in state_events # This can happen if event is rejected. if e_id in state_events # This can happen if event is rejected.
@ -392,11 +405,11 @@ class StateStore(SQLBaseStore):
full=(types is None), full=(types is None),
) )
results[group] = { results[group].update({
key: value key: value
for key, value in state_dict.items() for key, value in state_dict.items()
if value if value
} })
defer.returnValue(results) defer.returnValue(results)