Move state calculations from rest to handler

This commit is contained in:
Erik Johnston 2016-02-01 15:59:40 +00:00
parent 1ef7cae41b
commit fa48020a52
2 changed files with 97 additions and 140 deletions

View file

@ -23,6 +23,7 @@ from twisted.internet import defer
import collections
import logging
import itertools
logger = logging.getLogger(__name__)
@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
account_data_by_room,
all_ephemeral_by_room,
batch, full_state=False):
if full_state:
state = yield self.get_state_at(room_id, now_token)
elif batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state,
)
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
state = yield self.compute_state_delta(
room_id, batch, sync_config, since_token, now_token,
full_state=full_state
)
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event_id
state_events_delta = yield self.compute_state_delta(
room_id, batch, sync_config, since_token, leave_token,
full_state=full_state
)
if not full_state:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
else:
state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
@ -843,15 +800,18 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
@defer.inlineCallbacks
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
state to compare to
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
new state
:param str room_id
:param TimelineBatch batch
:param sync_config
:param str since_token
:param str now_token
:param bool full_state
:returns A new event dictionary
"""
@ -860,12 +820,50 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
state_delta = {}
for key, event in current_state.iteritems():
if (key not in previous_state or
previous_state[key].event_id != event.event_id):
state_delta[key] = event
return state_delta
if full_state:
if batch:
state = yield self.store.get_state_for_event(batch.events[0].event_id)
else:
state = yield self.get_state_at(
room_id, stream_position=now_token
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta):
"""
@ -912,3 +910,37 @@ def _action_has_highlight(actions):
pass
return False
def _calculate_state(timeline_contains, timeline_start, previous):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if there is an initial sync)
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
)
}
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}