Wait for lazy join to complete when getting current state (#12872)

This commit is contained in:
Erik Johnston 2022-06-01 16:02:53 +01:00 committed by GitHub
parent 782cb7420a
commit 888a29f412
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 361 additions and 82 deletions

View file

@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
NotFoundError if the room is unknown
"""
state_ids = await self.get_current_state_ids(room_id)
state_ids = await self.get_partial_current_state_ids(room_id)
if not state_ids:
raise NotFoundError(f"Current state for room {room_id} is empty")
@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
This may be the partial state if we're lazy joining the room.
Args:
room_id: The room to get the state IDs of.
@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
"get_partial_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
This may be the partial state if we're lazy joining the room.
Args:
room_id
state_filter: The state filter used to fetch state
@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not where_clause:
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
return await self.get_partial_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
room_id: The room ID
Returns:
The canonical alias, if any
"""
state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return None
event = await self.get_event(event_id, allow_none=True)
if not event:
return None
return event.content.get("canonical_alias")
@cached(max_entries=50000)
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(