Back out in-flight state caching changes. (#12126)

This commit is contained in:
reivilibre 2022-03-02 10:37:04 +00:00 committed by GitHub
parent 8e56a1b73c
commit c7b2f1ccdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 26 additions and 676 deletions

View file

@ -13,24 +13,11 @@
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
import attr
from sortedcontainers import SortedDict
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -42,12 +29,6 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@ -56,8 +37,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -73,24 +54,6 @@ class _GetStateGroupDelta:
return len(self.delta_ids) if self.delta_ids else 0
def state_filter_rough_priority_comparator(
state_filter: StateFilter,
) -> Tuple[int, int]:
"""
Returns a comparable value that roughly indicates the relative size of this
state filter compared to others.
'Larger' state filters should sort first when using ascending order, so
this is essentially the opposite of 'size'.
It should be treated as a rough guide only and should not be interpreted to
have any particular meaning. The representation may also change
The current implementation returns a tuple of the form:
* -1 for include_others, 0 otherwise
* -(number of entries in state_filter.types)
"""
return -int(state_filter.include_others), -len(state_filter.types)
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
@ -143,12 +106,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@ -200,7 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def _get_state_groups_from_groups(
self, groups: Sequence[int], state_filter: StateFilter
self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@ -271,170 +228,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.
If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
and there *still* isn't enough information to complete the request by solely
reusing others, a full state filter will be requested to ensure that subsequent
requests can reuse this request.
Used as part of _get_state_for_group_using_inflight_cache.
Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""
inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over
# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
# Iterate over existing requests in roughly biggest-first order.
for request_state_filter in inflight_requests:
request_deferred = inflight_requests[request_state_filter]
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue
reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break
if (
state_filter_left_over != StateFilter.none()
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
):
# There are too many requests for this group.
# To prevent even more from building up, we request the whole
# state filter to guarantee that we can be reused by any subsequent
# requests for this state group.
return (), StateFilter.all()
return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Used as part of _get_state_for_group_using_inflight_cache.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(
group, SortedDict(state_filter_rough_priority_comparator)
)
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.
1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""
# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}
gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)
# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)
# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@ -476,17 +269,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups:
return state
async def get_from_cache(group: int, state_filter: StateFilter) -> None:
state[group] = await self._get_state_for_group_using_inflight_cache(
group, state_filter
)
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
await yieldable_gather_results(
get_from_cache,
incomplete_groups,
state_filter,
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
)
# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
return state
def _get_state_for_groups_using_cache(