Fixes and optimisations for resolve_state_groups

First of all, fix the logic which looks for identical input state groups so
that we actually use them. This turned out to be most easily done by factoring
the relevant code out to a separate function so that we could do an early
return.

Secondly, avoid building the whole `conflicted_state` dict (which was only ever
used as a boolean flag).

Thirdly, replace the construction of the `state` dict (which mapped from keys
to events that set them), with an optimistic construction of the resolution
result assuming there will be no conflicts. This should be no slower than
building the old `state` dict, and:
  - in the conflicted case, we'll short-cut it, saving part of the work
  - in the unconflicted case, it saves rebuilding the resolution from the
    `state` dict.

Finally, do a couple of s/values/itervalues/.
This commit is contained in:
Richard van der Hoff 2018-07-23 19:00:16 +01:00
parent f559119de0
commit 5c705f70c9

View File

@ -471,69 +471,39 @@ class StateResolutionHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str])
state = {}
# start by assuming we won't have any conflicted state, and build up the new
# state map by iterating through the state groups. If we discover a conflict,
# we give up and instead use `resolve_events_with_factory`.
#
# XXX: is this actually worthwhile, or should we just let
# resolve_events_with_factory do it?
new_state = {}
conflicted_state = False
for st in itervalues(state_groups_ids):
for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
for k, v in iteritems(state)
if len(v) > 1
}
if key in new_state:
conflicted_state = True
break
new_state[key] = e_id
if conflicted_state:
break
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
list(state_groups_ids.values()),
list(itervalues(state_groups_ids)),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in iteritems(state)
}
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
with Measure(self.clock, "state.create_group_ids"):
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None
new_state_event_ids = frozenset(itervalues(new_state))
for sg, events in iteritems(state_groups_ids):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
# end up as a prev_group without being added to the database
prev_group = None
delta_ids = None
for old_group, old_ids in iteritems(state_groups_ids):
if not set(new_state) - set(old_ids):
n_delta_ids = {
k: v
for k, v in iteritems(new_state)
if old_ids.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
prev_group=prev_group,
delta_ids=delta_ids,
)
cache = _make_state_cache_entry(new_state, state_groups_ids)
if self._state_cache is not None:
self._state_cache[group_names] = cache
@ -541,6 +511,70 @@ class StateResolutionHandler(object):
defer.returnValue(cache)
def _make_state_cache_entry(
new_state,
state_groups_ids,
):
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
new_state (dict[(str, str), str]): resolved state map (mapping from
(type, state_key) to event_id)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
Returns:
_StateCacheEntry
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
# first look for exact matches
new_state_event_ids = set(itervalues(new_state))
for sg, state in iteritems(state_groups_ids):
if len(new_state_event_ids) != len(state):
continue
old_state_event_ids = set(itervalues(state))
if new_state_event_ids == old_state_event_ids:
# got an exact match.
return _StateCacheEntry(
state=new_state,
state_group=sg,
)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
# end up as a prev_group without being added to the database
# failing that, look for the closest match.
prev_group = None
delta_ids = None
for old_group, old_state in iteritems(state_groups_ids):
n_delta_ids = {
k: v
for k, v in iteritems(new_state)
if old_state.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
return _StateCacheEntry(
state=new_state,
state_group=None,
prev_group=prev_group,
delta_ids=delta_ids,
)
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
@ -582,7 +616,7 @@ def _seperate(state_sets):
with them in different state sets.
Args:
state_sets(list[dict[(str, str), str]]):
state_sets(iterable[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
@ -596,10 +630,11 @@ def _seperate(state_sets):
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
unconflicted_state = dict(state_sets[0])
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
for state_set in state_sets[1:]:
for state_set in state_set_iterator:
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)