Disable partial state group caching for wildcard lookups

When _get_state_for_groups is given a wildcard filter, just do a complete
lookup. Hopefully this will give us the best of both worlds: not filling up
RAM when we only need one or two keys, while still making the cache work
for the federation reader use case.
This commit is contained in:
Richard van der Hoff 2018-06-11 23:13:06 +01:00
parent 240f192523
commit 43e02c409d
3 changed files with 61 additions and 32 deletions

View File

@ -526,10 +526,23 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events """Gets the state at each of a list of state groups, optionally
with matching types. `types` is a list of `(type, state_key)`, where filtering by type/state_key
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned. Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
types (None|iterable[(str, None|str)]):
indicates the state type/keys required. If None, the whole
state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
""" """
if types: if types:
types = frozenset(types) types = frozenset(types)
@ -538,7 +551,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache( state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types group, types,
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -559,22 +572,40 @@ class StateGroupWorkerStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them. # Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence cache_seq_num = self._state_group_cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
# which makes wildcard lookups expensive unless we have a complete
# cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types):
types_to_fetch = None
else:
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types missing_groups, types_to_fetch,
) )
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group] state_dict = results[group]
state_dict.update(group_state_dict)
# update the result, filtering by `types`.
if types:
for k, v in iteritems(group_state_dict):
(typ, _) = k
if k in types or (typ, None) in types:
state_dict[k] = v
else:
state_dict.update(group_state_dict)
# update the cache with all the things we fetched from the
# database.
self._state_group_cache.update( self._state_group_cache.update(
cache_seq_num, cache_seq_num,
key=group, key=group,
value=state_dict, value=group_state_dict,
full=(types is None), fetched_keys=types_to_fetch,
known_absent=types,
) )
defer.returnValue(results) defer.returnValue(results)
@ -681,7 +712,6 @@ class StateGroupWorkerStore(SQLBaseStore):
self._state_group_cache.sequence, self._state_group_cache.sequence,
key=state_group, key=state_group,
value=dict(current_state_ids), value=dict(current_state_ids),
full=True,
) )
return state_group return state_group

View File

@ -107,29 +107,28 @@ class DictionaryCache(object):
self.sequence += 1 self.sequence += 1
self.cache.clear() self.cache.clear()
def update(self, sequence, key, value, full=False, known_absent=None): def update(self, sequence, key, value, fetched_keys=None):
"""Updates the entry in the cache """Updates the entry in the cache
Args: Args:
sequence sequence
key key (K)
value (dict): The value to update the cache with. value (dict[X,Y]): The value to update the cache with.
full (bool): Whether the given value is the full dict, or just a fetched_keys (None|set[X]): All of the dictionary keys which were
partial subset there of. If not full then any existing entries fetched from the database.
for the key will be updated.
known_absent (set): Set of keys that we know don't exist in the full If None, this is the complete value for key K. Otherwise, it
dict. is used to infer a list of keys which we know don't exist in
the full dict.
""" """
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
if known_absent is None: if fetched_keys is None:
known_absent = set() self._insert(key, value, set())
if full:
self._insert(key, value, known_absent)
else: else:
self._update_or_insert(key, value, known_absent) self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent): def _update_or_insert(self, key, value, known_absent):
# We pop and reinsert as we need to tell the cache the size may have # We pop and reinsert as we need to tell the cache the size may have

View File

@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"} test_value = {"test": "test_simple_cache_hit_full"}
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key) c = self.cache.get(key)
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = { test_value = {
"test": "test_simple_cache_hit_partial" "test": "test_simple_cache_hit_partial"
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"]) c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = { test_value = {
"test": "test_simple_cache_miss_partial" "test": "test_simple_cache_miss_partial"
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value) self.assertEqual({}, c.value)
@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2", "test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3", "test3": "test_simple_cache_hit_miss_partial3",
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = { test_value_1 = {
"test": "test_simple_cache_hit_miss_partial", "test": "test_simple_cache_hit_miss_partial",
} }
self.cache.update(seq, key, test_value_1, full=False) self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence seq = self.cache.sequence
test_value_2 = { test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2", "test2": "test_simple_cache_hit_miss_partial2",
} }
self.cache.update(seq, key, test_value_2, full=False) self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key) c = self.cache.get(key)
self.assertEqual( self.assertEqual(