Track and deduplicate in-flight requests to _get_state_for_groups. (#10870)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Author: reivilibre
Date: 2022-02-18 17:23:31 +00:00 (committed by GitHub)
Parent: e6acd3cf4f
Commit: 284ea2025a
3 changed files with 312 additions and 25 deletions
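
The gist of the change: when several callers ask for the state of the same state group at about the same time, only one database request should be issued, and later callers should observe the result of the first. Below is a minimal, hypothetical asyncio sketch of that pattern. The names InflightDeduplicator, slow_fetch and _demo are invented for illustration; the real code in this diff keys entries by (state group, StateFilter), wraps the work in an ObservableDeferred rather than an asyncio.Task, and runs under Twisted.

import asyncio
from typing import Awaitable, Callable, Dict


class InflightDeduplicator:
    """Illustrative only: collapse concurrent fetches for the same key."""

    def __init__(self, fetch: Callable[[int], Awaitable[dict]]) -> None:
        self._fetch = fetch
        self._inflight: Dict[int, "asyncio.Task[dict]"] = {}

    async def get(self, group: int) -> dict:
        task = self._inflight.get(group)
        if task is None:
            # No in-flight request for this group: fire one and track it.
            task = asyncio.ensure_future(self._fetch(group))
            self._inflight[group] = task
            # Drop the entry once the fetch finishes (success or failure),
            # mirroring the cleanup in _get_state_for_group_fire_request.
            task.add_done_callback(lambda _: self._inflight.pop(group, None))
        # Awaiting the same Task from several callers hands each of them the
        # same result, so only one underlying fetch ever runs per group.
        return await task


async def _demo() -> None:
    calls = 0

    async def slow_fetch(group: int) -> dict:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.1)  # stand-in for the database query
        return {("m.room.create", ""): f"$create-event-for-{group}"}

    dedup = InflightDeduplicator(slow_fetch)
    a, b = await asyncio.gather(dedup.get(1), dedup.get(1))
    assert a == b and calls == 1  # the second caller reused the first fetch


asyncio.run(_demo())

The extra subtlety in the real change is the second-level StateFilter key: a new request first tries to cover itself from overlapping in-flight filters via StateFilter.approx_difference, and only fires a database query for whatever is left over.
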


@@ -13,11 +13,23 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -37,7 +55,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
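# Entries are keyed by the expanded StateFilter actually sent to the
# database and are removed again as soon as the corresponding request
# completes (see _get_state_for_group_fire_request below).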
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.
Used as part of _get_state_for_group_using_inflight_cache.
Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""
inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over
# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
for (request_state_filter, request_deferred) in inflight_requests.items():
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
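# approx_difference may over-approximate (i.e. leave the filter broader
# than the exact difference); that is safe here because it only means we
# fetch some state ourselves that the in-flight request would also cover.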
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue
reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break
return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Used as part of _get_state_for_group_using_inflight_cache.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
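# The expanded filter is also what the in-flight entry is keyed on below,
# so broadening it makes it more likely that a later overlapping request
# can be satisfied by observing this one instead of hitting the database.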
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
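# Wrapping the deferred in an ObservableDeferred lets any number of later
# callers .observe() it and each receive their own copy of the result or
# failure; consumeErrors=True stops the failure from additionally being
# reported as an unhandled error on the underlying deferred.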
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.
1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""
# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}
gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)
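# With consumeErrors=True, gatherResults reports the first failure wrapped
# in a FirstError; unwrapFirstError unwraps it so the caller sees the
# original exception.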
# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)
# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups:
return state
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
async def get_from_cache(group: int, state_filter: StateFilter) -> None:
state[group] = await self._get_state_for_group_using_inflight_cache(
group, state_filter
)
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
await yieldable_gather_results(
get_from_cache,
incomplete_groups,
state_filter,
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
return state
def _get_state_for_groups_using_cache(