From ec0a523ac338bab1eb23a6b21227b8f7402cc2d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 18:37:18 +0000 Subject: [PATCH] Split out static state methods from StateHandler --- synapse/state.py | 142 ++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a8..c75499c3e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,6 +16,7 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure @@ -335,9 +336,10 @@ class StateHandler(object): [state_map[e_id] for key, e_id in st.items() if e_id in state_map] for st in state_groups_ids.values() ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) + with Measure(self.clock, "state._resolve_events"): + new_state, _ = Resolver.resolve_events( + state_sets, event_type, state_key + ) new_state = { key: e.event_id for key, e in new_state.items() } @@ -388,68 +390,78 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + with Measure(self.clock, "state._resolve_events"): + if event.is_state(): + return Resolver.resolve_events( + state_sets, event.type, event.state_key + ) + else: + return Resolver.resolve_events(state_sets) - def _resolve_events(self, state_sets, event_type=None, state_key=""): + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) + + +class Resolver(object): + @staticmethod + def resolve_events(state_sets, event_type=None, state_key=""): """ Returns (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple (new_state, prev_states). new_state is a map from (type, state_key) to event. prev_states is a list of event_ids. """ - with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + try: + resolved_state = Resolver._resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - new_state = unconflicted_state - new_state.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): + @staticmethod + def _resolve_state_events(conflicted_state, auth_events): """ This is where we actually decide which of the conflicted state to use. @@ -464,7 +476,7 @@ class StateHandler(object): if power_key in conflicted_state: events = conflicted_state[power_key] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( + resolved_state[power_key] = Resolver._resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -472,7 +484,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -482,7 +494,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -492,14 +504,15 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( + resolved_state[key] = Resolver._resolve_normal_events( events, auth_events ) return resolved_state - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] + @staticmethod + def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] auth_events = dict(auth_events) @@ -507,23 +520,20 @@ class StateHandler(object): for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event return event - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): + @staticmethod + def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) return event except AuthError: pass @@ -531,9 +541,3 @@ class StateHandler(object): # Use the last event (the one with the least depth) if they all fail # the auth check. return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func)