Split resolve_events into two functions

... so that the return type doesn't depend on the arg types
This commit is contained in:
Richard van der Hoff 2018-01-17 15:44:31 +00:00
parent a7e4ff9cca
commit 390093d45e
2 changed files with 29 additions and 20 deletions

View File

@ -341,7 +341,7 @@ class StateHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events( state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False, ev_ids, get_prev_content=False, check_redacted=False,
@ -404,7 +404,7 @@ class StateHandler(object):
} }
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events(state_set_ids, state_map) new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items() key: state_map[ev_id] for key, ev_id in new_state.items()
@ -420,19 +420,17 @@ def _ordered_events(events):
return sorted(events, key=key_func) return sorted(events, key=key_func)
def resolve_events(state_sets, state_map_factory): def resolve_events_with_state_map(state_sets, state_map):
""" """
Args: Args:
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called state_map(dict): a dict from event_id to event, for all events in
with a list of event_ids that are needed, and should return with state_sets.
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from dict[(str, str), synapse.events.FrozenEvent]:
(type, state_key) to event. a map from (type, state_key) to event.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
return state_sets[0] return state_sets[0]
@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets, state_sets,
) )
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps( auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )
@ -491,8 +482,26 @@ def _seperate(state_sets):
@defer.inlineCallbacks @defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state, def resolve_events_with_factory(state_sets, state_map_factory):
state_map_factory): """
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
a map from (type, state_key) to event.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
needed_events = set( needed_events = set(
event_id event_id
for event_ids in conflicted_state.itervalues() for event_ids in conflicted_state.itervalues()

View File

@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.state import resolve_events from synapse.state import resolve_events_with_factory
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -557,7 +557,7 @@ class EventsStore(SQLBaseStore):
to_return.update(evs) to_return.update(evs)
defer.returnValue(to_return) defer.returnValue(to_return)
current_state = yield resolve_events( current_state = yield resolve_events_with_factory(
state_sets, state_sets,
state_map_factory=get_events, state_map_factory=get_events,
) )