Fix state cache

This commit is contained in:
Erik Johnston 2015-08-11 09:12:41 +01:00
parent 017b798e4f
commit 10b874067b

View File

@ -283,6 +283,9 @@ class StateStore(SQLBaseStore):
def _get_state_for_group_from_cache(self, group, types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the
list of types that aren't in the cache for that group.
"""
is_all, state_dict = self._state_group_cache.get(group)
@ -300,13 +303,17 @@ class StateStore(SQLBaseStore):
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
if is_all and types is None:
return state_dict, missing_types
if is_all:
missing_types = set()
if types is None:
return state_dict, set(), True
if is_all or (types is not None and not missing_types):
sentinel = object()
def include(typ, state_key):
if types is None:
return True
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
@ -319,10 +326,8 @@ class StateStore(SQLBaseStore):
return {
k: v
for k, v in state_dict.items()
if v and include(k[0], k[1])
}, missing_types
return {}, missing_types
if include(k[0], k[1])
}, missing_types, not missing_types and types is not None
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
@ -333,25 +338,28 @@ class StateStore(SQLBaseStore):
"""
results = {}
missing_groups_and_types = []
for group in groups:
state_dict, missing_types = self._get_state_for_group_from_cache(
for group in set(groups):
state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
group, types
)
if types is not None and not missing_types:
results[group] = {
key: value
for key, value in state_dict.items()
if value
}
else:
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((
group,
missing_types if types else None
))
if not missing_groups_and_types:
defer.returnValue(results)
defer.returnValue({
k: {
key: ev
for key, ev in state.items()
if ev
}
for k, state in results.items()
})
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
@ -371,10 +379,15 @@ class StateStore(SQLBaseStore):
}
for group, state_ids in group_state_dict.items():
if types:
state_dict = {
key: None
for key in missing_types
for key in types
}
state_dict.update(results[group])
else:
state_dict = results[group]
evs = [
state_events[e_id] for e_id in state_ids
if e_id in state_events # This can happen if event is rejected.
@ -392,11 +405,11 @@ class StateStore(SQLBaseStore):
full=(types is None),
)
results[group] = {
results[group].update({
key: value
for key, value in state_dict.items()
if value
}
})
defer.returnValue(results)