Update EventContext get_current_event_ids and get_prev_event_ids to accept state filters and update calls where possible (#12791)

This commit is contained in:
Shay 2022-05-20 01:54:12 -07:00 committed by GitHub
parent 2be5a2b07b
commit 71e8afe34d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 65 additions and 18 deletions

View file

@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
@attr.s(slots=True, auto_attribs=True)
@ -196,7 +197,9 @@ class EventContext:
return self._state_group
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
async def get_current_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@ -204,6 +207,9 @@ class EventContext:
not make it into the room state. This method will raise an exception if
``rejected`` is set.
Arg:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
@ -216,7 +222,7 @@ class EventContext:
assert self._state_delta_due_to_event is not None
prev_state_ids = await self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids(state_filter)
if self._state_delta_due_to_event:
prev_state_ids = dict(prev_state_ids)
@ -224,12 +230,17 @@ class EventContext:
return prev_state_ids
async def get_prev_state_ids(self) -> StateMap[str]:
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Args:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Returns {} if state_group is None, which happens when the associated
event is an outlier.
@ -239,7 +250,7 @@ class EventContext:
"""
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event
self.state_group_before_event, state_filter
)