From b975fa2e9952f1f8ac2cddb15c287768bf9b0b4e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 10:59:51 -0400 Subject: [PATCH] Convert state resolution to async/await (#7942) --- changelog.d/7942.misc | 1 + synapse/api/auth.py | 12 +- synapse/events/builder.py | 4 +- synapse/federation/sender/__init__.py | 4 +- synapse/handlers/presence.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/state/__init__.py | 95 +++++++--------- synapse/state/v1.py | 15 +-- synapse/state/v2.py | 107 ++++++++---------- synapse/storage/data_stores/main/push_rule.py | 2 +- .../storage/data_stores/main/roommember.py | 2 +- .../data_stores/main/user_directory.py | 4 +- synapse/storage/persist_events.py | 5 +- tests/federation/test_federation_sender.py | 19 ++-- tests/state/test_v2.py | 17 +-- tests/storage/test_room.py | 8 +- tests/test_state.py | 72 +++++++----- tests/test_utils/__init__.py | 7 +- 18 files changed, 198 insertions(+), 184 deletions(-) create mode 100644 changelog.d/7942.misc diff --git a/changelog.d/7942.misc b/changelog.d/7942.misc new file mode 100644 index 000000000..b504cf4e6 --- /dev/null +++ b/changelog.d/7942.misc @@ -0,0 +1 @@ +Convert state resolution to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40dc62ef6..b53e8451e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -127,8 +127,10 @@ class Auth(object): if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id + member = yield defer.ensureDeferred( + self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id + ) ) membership = member.membership if member else None @@ -665,8 +667,10 @@ class Auth(object): ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" + visibility = yield defer.ensureDeferred( + self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) ) if ( visibility diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 92aadfe7e..0bb216419 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,8 +106,8 @@ class EventBuilder(object): Deferred[FrozenEvent] """ - state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids + state_ids = yield defer.ensureDeferred( + self._state.get_current_state_ids(self.room_id, prev_event_ids) ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 99ce73e08..ba4ddd237 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -330,7 +330,9 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield self.state.get_current_hosts_in_room(room_id) + domains = yield defer.ensureDeferred( + self.state.get_current_hosts_in_room(room_id) + ) domains = [ d for d in domains diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8e99c83d9..b3a3bb8c3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. - user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 43ffe6faf..472ddf9f7 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -304,7 +304,9 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred( + context.get_current_state_ids() + ) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 495d9f04c..25ccef5aa 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,14 +16,12 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set +from typing import Awaitable, Dict, Iterable, List, Optional, Set import attr from frozendict import frozendict from prometheus_client import Histogram -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase @@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -108,8 +107,7 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def get_current_state( + async def get_current_state( self, room_id, event_type=None, state_key="", latest_event_ids=None ): """ Retrieves the current state for the room. This is done by @@ -126,20 +124,20 @@ class StateHandler(object): map from (type, state_key) to event """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: event_id = state.get((event_type, state_key)) event = None if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) return event - state_map = yield self.store.get_events( + state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) state = { @@ -148,8 +146,7 @@ class StateHandler(object): return state - @defer.inlineCallbacks - def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids(self, room_id, latest_event_ids=None): """Get the current state, or the state at a set of events, for a room Args: @@ -164,41 +161,38 @@ class StateHandler(object): (event_type, state_key) -> event_id """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state return state - @defer.inlineCallbacks - def get_current_users_in_room(self, room_id, latest_event_ids=None): + async def get_current_users_in_room( + self, room_id: str, latest_event_ids: Optional[List[str]] = None + ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. Args: - room_id (str): The ID of the room. - latest_event_ids (List[str]|None): Precomputed list of latest - event IDs. Will be computed if None. + room_id: The ID of the room. + latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their - profileinfo. + Dictionary of user IDs to their profileinfo. """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_users_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_users = await self.store.get_joined_users_from_state(room_id, entry) return joined_users - @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id): - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + async def get_current_hosts_in_room(self, room_id): + event_ids = await self.store.get_latest_event_ids_in_room(room_id) + return await self.get_hosts_in_room_at_events(room_id, event_ids) - @defer.inlineCallbacks - def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events(self, room_id, event_ids): """Get the hosts that were in a room at the given event ids Args: @@ -208,12 +202,11 @@ class StateHandler(object): Returns: Deferred[list[str]]: the hosts in the room at the given events """ - entry = yield self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = yield self.store.get_joined_hosts(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, event_ids) + joined_hosts = await self.store.get_joined_hosts(room_id, entry) return joined_hosts - @defer.inlineCallbacks - def compute_event_context( + async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None ): """Build an EventContext structure for the event. @@ -278,7 +271,7 @@ class StateHandler(object): # otherwise, we'll need to resolve the state across the prev_events. logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups_for_events( + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -295,7 +288,7 @@ class StateHandler(object): # if not state_group_before_event: - state_group_before_event = yield self.state_store.store_state_group( + state_group_before_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -335,7 +328,7 @@ class StateHandler(object): state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = yield self.state_store.store_state_group( + state_group_after_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -353,8 +346,7 @@ class StateHandler(object): ) @measure_func() - @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -373,7 +365,7 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.state_store.get_state_groups_ids( + state_groups_ids = await self.state_store.get_state_groups_ids( room_id, event_ids ) @@ -382,7 +374,7 @@ class StateHandler(object): elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) + prev_group, delta_ids = await self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -391,9 +383,9 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) - result = yield self._state_resolution_handler.resolve_state_groups( + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, @@ -402,8 +394,7 @@ class StateHandler(object): ) return result - @defer.inlineCallbacks - def resolve_events(self, room_version, state_sets, event): + async def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,7 +405,7 @@ class StateHandler(object): state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, event.room_id, room_version, @@ -451,9 +442,8 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - @defer.inlineCallbacks @log_function - def resolve_state_groups( + async def resolve_state_groups( self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -479,13 +469,13 @@ class StateResolutionHandler(object): state_res_store (StateResolutionStore) Returns: - Deferred[_StateCacheEntry]: resolved state + _StateCacheEntry: resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) - with (yield self.resolve_linearizer.queue(group_names)): + with (await self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: @@ -517,7 +507,7 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, room_id, room_version, @@ -598,7 +588,7 @@ def resolve_events_with_store( state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", -): +) -> Awaitable[StateMap[str]]: """ Args: room_id: the room we are working in @@ -619,8 +609,7 @@ def resolve_events_with_store( state_res_store: a place to fetch events from Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 7b531a833..ab5e24841 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,9 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional - -from twisted.internet import defer +from typing import Awaitable, Callable, Dict, List, Optional from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,12 +30,11 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( room_id: str, state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable, + state_map_factory: Callable[[List[str]], Awaitable], ): """ Args: @@ -56,7 +53,7 @@ def resolve_events_with_store( state_map_factory: will be called with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + an Awaitable that resolves to a dict of event_id to event. Returns: Deferred[dict[(str, str), str]]: @@ -80,7 +77,7 @@ def resolve_events_with_store( # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) + state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -110,7 +107,7 @@ def resolve_events_with_store( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) - state_map_new = yield state_map_factory(new_needed_events) + state_map_new = await state_map_factory(new_needed_events) for event in state_map_new.values(): if event.room_id != room_id: raise Exception( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bf6caa094..6634955cd 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from twisted.internet import defer - import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,14 +30,13 @@ from synapse.util import Clock logger = logging.getLogger(__name__) -# We want to yield to the reactor occasionally during state res when dealing +# We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by -# yielding to reactor during loops every N iterations. -_YIELD_AFTER_ITERATIONS = 100 +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, @@ -87,7 +84,7 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( @@ -95,7 +92,7 @@ def resolve_events_with_store( ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -118,14 +115,14 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( + sorted_power_events = await _reverse_topological_power_sort( clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -148,13 +145,13 @@ def resolve_events_with_store( logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( + leftover_events = await _mainline_sort( clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -174,8 +171,7 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. @@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference(state_sets, event_map, state_res_store): """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. @@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Deferred[set[str]]: Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) @@ -292,8 +287,7 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( +async def _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event @@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( +async def _reverse_topological_power_sort( clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, @@ -344,26 +337,26 @@ def _reverse_topological_power_sort( graph = {} for idx, event_id in enumerate(event_ids, start=1): - yield _add_event_and_auth_chain_to_graph( + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_to_pl = {} for idx, event_id in enumerate(graph, start=1): - pl = yield _get_power_level_for_sender( + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) def _get_power_order(event_id): ev = event_map[event_id] @@ -378,8 +371,7 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( +async def _iterative_auth_checks( clock, room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the @@ -405,7 +397,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -420,7 +412,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -438,16 +430,15 @@ def _iterative_auth_checks( except AuthError: pass - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) return resolved_state -@defer.inlineCallbacks -def _mainline_sort( +async def _mainline_sort( clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on @@ -474,21 +465,21 @@ def _mainline_sort( idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) idx += 1 @@ -498,23 +489,24 @@ def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): - depth = yield _get_mainline_depth_for_event( + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event, mainline_map, event_map, state_res_store +): """Get the mainline depths for the given event based on the mainline map Args: @@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store @@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d181488db..c22924810 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -259,7 +259,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 29765890e..a92e401e8 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 6b8130bf0..942e51fd3 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) user_ids = set(users_with_profile) # Update each user in the user directory. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa4604167..78fbdcdee 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,7 +29,6 @@ from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap @@ -648,6 +647,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 1a9bd5f37..d1bd18da3 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -26,21 +26,24 @@ from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 38f9b423e..f2955a9c6 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +from typing import List import attr @@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -608,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -622,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -648,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index b1dceb291..1d77b4a2d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f681..4858e8fc5 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -97,17 +97,19 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return state_group + return defer.succeed(state_group) def get_events(self, event_ids, **kwargs): - return { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } + return defer.succeed( + { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } + ) def get_state_group_delta(self, name): - return None, None + return defer.succeed((None, None)) def register_events(self, events): for e in events: @@ -120,7 +122,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version_id(self, room_id): - return RoomVersions.V1.identifier + return defer.succeed(RoomVersions.V1.identifier) class DictObj(dict): @@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) prev_state_ids = yield context.get_prev_state_ids() @@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( + sg1 = yield self.store.store_state_group( prev_event_id_1, event.room_id, None, @@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( + sg2 = yield self.store.store_state_group( prev_event_id_2, event.room_id, None, @@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03b..508aeba07 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result