Factor out resolve_state_groups to a separate handler

We extract the storage-independent bits of the state group resolution out to a
separate functiom, and stick it in a new handler, in preparation for its use
from the storage layer.
This commit is contained in:
Richard van der Hoff 2018-01-27 09:15:45 +00:00
parent 0cbda53819
commit 6da4c4d3bd
4 changed files with 108 additions and 54 deletions

View File

@ -66,7 +66,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
) )
from synapse.state import StateHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.util import Clock from synapse.util import Clock
@ -102,6 +102,7 @@ class HomeServer(object):
'v1auth', 'v1auth',
'auth', 'auth',
'state_handler', 'state_handler',
'state_resolution_handler',
'presence_handler', 'presence_handler',
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
@ -224,6 +225,9 @@ class HomeServer(object):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_state_resolution_handler(self):
return StateResolutionHandler(self)
def build_presence_handler(self): def build_presence_handler(self):
return PresenceHandler(self) return PresenceHandler(self)

View File

@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass pass

View File

@ -81,31 +81,19 @@ class _StateCacheEntry(object):
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """Fetches bits of state from the stores, and does state resolution
where necessary
""" """
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self): def start_caching(self):
logger.debug("start_caching") # TODO: remove this shim
self._state_resolution_handler.start_caching()
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="", def get_current_state(self, room_id, event_type=None, state_key="",
@ -283,7 +271,6 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function
def resolve_state_groups_for_events(self, room_id, event_ids): def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -303,13 +290,7 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
logger.debug( if len(state_groups_ids) == 1:
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -321,6 +302,92 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, self._state_map_factory,
)
defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
class StateResolutionHandler(object):
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
@ -351,15 +418,17 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events( state_map_factory=state_map_factory,
ev_ids, get_prev_content=False, check_redacted=False,
),
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
} }
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items(): for sg, events in state_groups_ids.items():
@ -396,30 +465,6 @@ class StateHandler(object):
defer.returnValue(cache) defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock from .utils import MockClock
@ -148,11 +148,13 @@ class StateTestCase(unittest.TestCase):
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock", "get_datastore", "get_auth", "get_state_handler", "get_clock",
"get_state_resolution_handler",
]) ])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.store.get_next_state_group.side_effect = Mock self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None) self.store.get_state_group_delta.return_value = (None, None)