Convert state resolution to async/await (#7942)

This commit is contained in:
Patrick Cloke 2020-07-24 10:59:51 -04:00 committed by GitHub
parent e739b20588
commit b975fa2e99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 198 additions and 184 deletions

View file

@ -16,14 +16,12 @@
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set
from typing import Awaitable, Dict, Iterable, List, Optional, Set
import attr
from frozendict import frozendict
from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@ -108,8 +107,7 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def get_current_state(
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
@ -126,20 +124,20 @@ class StateHandler(object):
map from (type, state_key) to event
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = yield self.store.get_events(
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
@ -148,8 +146,7 @@ class StateHandler(object):
return state
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
@ -164,41 +161,38 @@ class StateHandler(object):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
room_id (str): The ID of the room.
latest_event_ids (List[str]|None): Precomputed list of latest
event IDs. Will be computed if None.
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
profileinfo.
Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id):
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
async def get_current_hosts_in_room(self, room_id):
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
@defer.inlineCallbacks
def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
@ -208,12 +202,11 @@ class StateHandler(object):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
@defer.inlineCallbacks
def compute_event_context(
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
@ -278,7 +271,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events(
entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@ -295,7 +288,7 @@ class StateHandler(object):
#
if not state_group_before_event:
state_group_before_event = yield self.state_store.store_state_group(
state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@ -335,7 +328,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
state_group_after_event = yield self.state_store.store_state_group(
state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@ -353,8 +346,7 @@ class StateHandler(object):
)
@measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@ -373,7 +365,7 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = yield self.state_store.get_state_groups_ids(
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@ -382,7 +374,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@ -391,9 +383,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
room_version = yield self.store.get_room_version_id(room_id)
room_version = await self.store.get_room_version_id(room_id)
result = yield self._state_resolution_handler.resolve_state_groups(
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@ -402,8 +394,7 @@ class StateHandler(object):
)
return result
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@ -414,7 +405,7 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
@ -451,9 +442,8 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@ -479,13 +469,13 @@ class StateResolutionHandler(object):
state_res_store (StateResolutionStore)
Returns:
Deferred[_StateCacheEntry]: resolved state
_StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)):
with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@ -517,7 +507,7 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1: