Handle the fact that workers can't generate state groups

This commit is contained in:
Erik Johnston 2016-08-31 15:53:19 +01:00
parent f51888530d
commit ed7a703d4c
2 changed files with 60 additions and 27 deletions

View File

@ -281,11 +281,13 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, curr_state_ids = yield self.state.resolve_state_groups(
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):

View File

@ -43,11 +43,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object):
def __init__(self, state, state_group, ts):
__slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state
self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@ -93,7 +117,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
@ -116,7 +141,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
@ -127,9 +153,9 @@ class StateHandler(object):
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, group, state_ids
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users)
@ -191,23 +217,26 @@ class StateHandler(object):
defer.returnValue(context)
if event.is_state():
ret = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
)
group, curr_state = ret
curr_state = entry.state
context.prev_state_ids = curr_state
if event.is_state() or group is None:
if event.is_state():
context.state_group = self.store.get_next_state_group()
else:
context.state_group = group
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state():
key = (event.type, event.state_key)
@ -249,16 +278,15 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
defer.returnValue((name, state_list,))
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
cache.ts = self.clock.time_msec()
defer.returnValue(
(cache.state_group, cache.state,)
)
defer.returnValue(cache)
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@ -302,19 +330,22 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
if not state_group:
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
if self._state_cache is not None:
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
ts=self.clock.time_msec()
)
if self._state_cache is not None:
self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state,))
defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(