Merge pull request #573 from matrix-org/erikj/sync_fix

Mitigate against incorrect old state in /sync.
This commit is contained in:
Erik Johnston 2016-02-18 16:40:58 +00:00
commit 220231d8e3

View File

@ -823,15 +823,17 @@ class SyncHandler(BaseHandler):
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
current_state = yield self.get_state_at(
room_id, stream_position=now_token
)
if full_state: if full_state:
if batch: if batch:
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
state = yield self.get_state_at( state = current_state
room_id, stream_position=now_token
)
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event
@ -842,6 +844,7 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state,
previous={}, previous={},
current=current_state,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
@ -861,6 +864,7 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state,
) )
else: else:
state = {} state = {}
@ -920,7 +924,7 @@ def _action_has_highlight(actions):
return False return False
def _calculate_state(timeline_contains, timeline_start, previous): def _calculate_state(timeline_contains, timeline_start, previous, current):
"""Works out what state to include in a sync response. """Works out what state to include in a sync response.
Args: Args:
@ -928,6 +932,7 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_start (dict): state at the start of the timeline timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync) if this is an initial sync)
current (dict): state at the end of the timeline
Returns: Returns:
dict dict
@ -938,14 +943,16 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_contains.values(), timeline_contains.values(),
previous.values(), previous.values(),
timeline_start.values(), timeline_start.values(),
current.values(),
) )
} }
c_ids = set(e.event_id for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids) evs = (event_id_to_state[e] for e in state_ids)
return { return {