Don't hold onto full state in state cache (#13324)

This commit is contained in:
Erik Johnston 2022-07-21 16:02:02 +01:00 committed by GitHub
parent 10e4093839
commit 13341dde5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 16 deletions

1
changelog.d/13324.misc Normal file
View File

@ -0,0 +1 @@
Reduce the amount of state we store in the `state_cache`.

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import heapq import heapq
import logging import logging
from collections import defaultdict from collections import ChainMap, defaultdict
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -92,8 +92,11 @@ class _StateCacheEntry:
prev_group: Optional[int] = None, prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None, delta_ids: Optional[StateMap[str]] = None,
): ):
if state is None and state_group is None: if state is None and state_group is None and prev_group is None:
raise Exception("Either state or state group must be not None") raise Exception("One of state, state_group or prev_group must be not None")
if prev_group is not None and delta_ids is None:
raise Exception("If prev_group is set so must delta_ids")
# A map from (type, state_key) to event_id. # A map from (type, state_key) to event_id.
# #
@ -120,18 +123,48 @@ class _StateCacheEntry:
if self._state is not None: if self._state is not None:
return self._state return self._state
assert self.state_group is not None if self.state_group is not None:
return await state_storage.get_state_ids_for_group( return await state_storage.get_state_ids_for_group(
self.state_group, state_filter self.state_group, state_filter
) )
def __len__(self) -> int: assert self.prev_group is not None and self.delta_ids is not None
# The len should is used to estimate how large this cache entry is, for
# cache eviction purposes. This is why if `self.state` is None it's fine
# to return 1.
return len(self._state) if self._state else 1 prev_state = await state_storage.get_state_ids_for_group(
self.prev_group, state_filter
)
# ChainMap expects MutableMapping, but since we're using it immutably
# its safe to give it immutable maps.
return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
def set_state_group(self, state_group: int) -> None:
"""Update the state group assigned to this state (e.g. after we've
persisted it).
Note: this will cause the cache entry to drop any stored state.
"""
self.state_group = state_group
# We clear out the state as we know longer need to explicitly keep it in
# the `state_cache` (as the store state group cache will do that).
self._state = None
def __len__(self) -> int:
# The len should be used to estimate how large this cache entry is, for
# cache eviction purposes. This is why it's fine to return 1 if we're
# not storing any state.
length = 0
if self._state:
length += len(self._state)
if self.delta_ids:
length += len(self.delta_ids)
return length or 1 # Make sure its not 0.
class StateHandler: class StateHandler:
@ -320,7 +353,7 @@ class StateHandler:
current_state_ids=state_ids_before_event, current_state_ids=state_ids_before_event,
) )
) )
entry.state_group = state_group_before_event entry.set_state_group(state_group_before_event)
else: else:
state_group_before_event = entry.state_group state_group_before_event = entry.state_group
@ -747,7 +780,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(state.values()) old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids: if new_state_event_ids == old_state_event_ids:
# got an exact match. # got an exact match.
return _StateCacheEntry(state=new_state, state_group=sg) return _StateCacheEntry(state=None, state_group=sg)
# TODO: We want to create a state group for this set of events, to # TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't # increase cache hits, but we need to make sure that it doesn't
@ -769,9 +802,14 @@ def _make_state_cache_entry(
prev_group = old_group prev_group = old_group
delta_ids = n_delta_ids delta_ids = n_delta_ids
if prev_group is not None:
# If we have a prev group and deltas then we can drop the new state from
# the cache (to reduce memory usage).
return _StateCacheEntry( return _StateCacheEntry(
state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
) )
else:
return _StateCacheEntry(state=new_state, state_group=None)
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)