Correctly handle the difference between prev and current state

This commit is contained in:
Erik Johnston 2016-08-31 13:55:02 +01:00
parent 1bb8ec296d
commit c10cb581c6
12 changed files with 102 additions and 69 deletions

View file

@ -128,7 +128,7 @@ class StateHandler(object):
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_context(
joined_users = yield self.store.get_joined_users_from_state(
room_id, group, state_ids
)
defer.returnValue(joined_users)
@ -154,27 +154,38 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state_ids = {
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
context.current_state_events = dict(context.prev_state_ids)
key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id
else:
context.current_state_events = context.prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
context.state_group = self.store.get_next_state_group()
defer.returnValue(context)
if old_state:
context.current_state_ids = {
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state_ids:
replaces = context.current_state_ids[key]
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
@ -192,7 +203,7 @@ class StateHandler(object):
group, curr_state = ret
context.current_state_ids = curr_state
context.prev_state_ids = curr_state
if event.is_state() or group is None:
context.state_group = self.store.get_next_state_group()
else:
@ -200,9 +211,13 @@ class StateHandler(object):
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state_ids:
replaces = context.current_state_ids[key]
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)