Reduce amount of state pulled out when querying federation hierachy (#16785)

There are two changes here:

1. Only pull out the required state when handling the request.
2. Change the get filtered state return type to check that we're only
querying state that was requested

---------

Co-authored-by: reivilibre <oliverw@matrix.org>
This commit is contained in:
Erik Johnston 2024-01-10 14:31:35 +00:00 committed by GitHub
parent 4c67f0391b
commit 578c5c736e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 3 deletions

1
changelog.d/16785.misc Normal file
View File

@ -0,0 +1 @@
Reduce amount of state pulled out when querying federation hierachy.

View File

@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection
from synapse.types.state import StateFilter
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@ -546,7 +547,16 @@ class RoomSummaryHandler:
Returns:
True if the room is accessible to the requesting user or server.
"""
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
event_types = [
(EventTypes.JoinRules, ""),
(EventTypes.RoomHistoryVisibility, ""),
]
if requester:
event_types.append((EventTypes.Member, requester))
state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, state_filter=StateFilter.from_types(event_types)
)
# If there's no state for the room, it isn't known.
if not state_ids:

View File

@ -30,7 +30,10 @@ from typing import (
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)
import attr
@ -52,7 +55,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@ -64,6 +67,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_T = TypeVar("_T")
MAX_STATE_DELTA_HOPS = 100
@ -349,7 +354,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@ -726,3 +732,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@attr.s(auto_attribs=True, slots=True)
class StateMapWrapper(Dict[StateKey, str]):
"""A wrapper around a StateMap[str] to ensure that we only query for items
that were not filtered out.
This is to help prevent bugs where we filter out state but other bits of the
code expect the state to be there.
"""
state_filter: StateFilter
def __getitem__(self, key: StateKey) -> str:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().__getitem__(key)
@overload
def get(self, key: Tuple[str, str]) -> Optional[str]:
...
@overload
def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
...
def get(
self, key: StateKey, default: Union[str, _T, None] = None
) -> Union[str, _T, None]:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().get(key, default)
def __contains__(self, key: Any) -> bool:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().__contains__(key)

View File

@ -20,6 +20,7 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
@ -584,6 +585,29 @@ class StateFilter:
# local users only
return False
def __contains__(self, key: Any) -> bool:
if not isinstance(key, tuple) or len(key) != 2:
raise TypeError(
f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}"
)
typ, state_key = key
if not isinstance(typ, str) or not isinstance(state_key, str):
raise TypeError(
f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})"
)
if typ in self.types:
state_keys = self.types[typ]
if state_keys is None or state_key in state_keys:
return True
elif self.include_others:
return True
return False
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(