Choose state algorithm based on room version

This commit is contained in:
Erik Johnston 2018-08-08 17:01:57 +01:00
parent 152c0aa58e
commit ce6db0e547
4 changed files with 105 additions and 16 deletions

View File

@ -274,8 +274,9 @@ class FederationHandler(BaseHandler):
ev_ids, get_prev_content=False, check_redacted=False ev_ids, get_prev_content=False, check_redacted=False
) )
room_version = yield self.store.get_room_version(pdu.room_id)
state_map = yield resolve_events_with_factory( state_map = yield resolve_events_with_factory(
state_groups, {pdu.event_id: pdu}, fetch room_version, state_groups, {pdu.event_id: pdu}, fetch
) )
state = (yield self.store.get_events(state_map.values())).values() state = (yield self.store.get_events(state_map.values())).values()
@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
new_state = self.state_handler.resolve_events( room_version = yield self.store.get_room_version(event.room_id)
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())], [list(local_view.values()), list(remote_view.values())],
event event
) )

View File

@ -341,9 +341,10 @@ class RoomMemberHandler(object):
prev_events_and_hashes = yield self.store.get_prev_events_for_room( prev_events_and_hashes = yield self.store.get_prev_events_for_room(
room_id, room_id,
) )
latest_event_ids = ( latest_event_ids = [
event_id for (event_id, _, _) in prev_events_and_hashes event_id for (event_id, _, _) in prev_events_and_hashes
) ]
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )

View File

@ -23,16 +23,15 @@ from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, RoomVersions
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import v1
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .v1 import resolve_events_with_factory, resolve_events_with_state_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -263,8 +262,14 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
if event.type == EventTypes.Create:
room_version = event.content.get("room_version", RoomVersions.V1)
else:
room_version = None
entry = yield self.resolve_state_groups_for_events( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
explicit_room_version=room_version,
) )
prev_state_ids = entry.state prev_state_ids = entry.state
@ -332,13 +337,17 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids): def resolve_state_groups_for_events(self, room_id, event_ids,
explicit_room_version=None):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args: Args:
room_id (str): room_id (str)
event_ids (list[str]): event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
Returns: Returns:
Deferred[_StateCacheEntry]: resolved state Deferred[_StateCacheEntry]: resolved state
@ -364,8 +373,13 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
room_version = explicit_room_version
if not room_version:
room_version = yield self.store.get_room_version(room_id)
result = yield self._state_resolution_handler.resolve_state_groups( result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, None, self._state_map_factory, room_id, room_version, state_groups_ids, None,
self._state_map_factory,
) )
defer.returnValue(result) defer.returnValue(result)
@ -374,7 +388,8 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False, ev_ids, get_prev_content=False, check_redacted=False,
) )
def resolve_events(self, state_sets, event): @defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
@ -389,14 +404,18 @@ class StateHandler(object):
for ev in st for ev in st
} }
room_version = yield self.store.get_room_version(event.room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(
room_version, state_set_ids, state_map,
)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state) key: state_map[ev_id] for key, ev_id in iteritems(new_state)
} }
return new_state defer.returnValue(new_state)
class StateResolutionHandler(object): class StateResolutionHandler(object):
@ -429,7 +448,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups( def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory, self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
): ):
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
@ -438,6 +457,7 @@ class StateResolutionHandler(object):
Args: Args:
room_id (str): room we are resolving for (used for logging) room_id (str): room we are resolving for (used for logging)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group map from state group id to the state in that state group
(where 'state' is a map from state key to event id) (where 'state' is a map from state key to event id)
@ -491,6 +511,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
room_version,
list(itervalues(state_groups_ids)), list(itervalues(state_groups_ids)),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_map_factory=state_map_factory,
@ -572,3 +593,64 @@ def _make_state_cache_entry(
prev_group=prev_group, prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,
) )
def resolve_events_with_state_map(room_version, state_sets, state_map):
"""
Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return v1.resolve_events_with_state_map(
state_sets, state_map,
)
else:
# This should only happen if we added a version but forgot to add it to
# the list above.
raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,)
)
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""
Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return v1.resolve_events_with_factory(
state_sets, event_map, state_map_factory,
)
else:
# This should only happen if we added a version but forgot to add it to
# the list above.
raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,)
)

View File

@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
} }
events_map = {ev.event_id: ev for ev, _ in events_context} events_map = {ev.event_id: ev for ev, _ in events_context}
room_version = yield self.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events") logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups( res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events room_id, room_version, state_groups, events_map, get_events
) )
defer.returnValue((res.state, None)) defer.returnValue((res.state, None))