Move logic from rest/ to handlers/

This commit is contained in:
Erik Johnston 2016-01-25 10:10:44 +00:00
parent 88baa3865e
commit 4021f95261
3 changed files with 181 additions and 87 deletions

View File

@ -190,18 +190,16 @@ class Filter(object):
Returns: Returns:
bool: True if the event matches bool: True if the event matches
""" """
if isinstance(event, dict): sender = event.get("sender", None)
return self.check_fields( if not sender:
event.get("room_id", None), # Presence events have their 'sender' in content.user_id
event.get("sender", None), sender = event.get("conntent", {}).get("user_id", None)
event.get("type", None),
) return self.check_fields(
else: event.get("room_id", None),
return self.check_fields( sender,
getattr(event, "room_id", None), event.get("type", None),
getattr(event, "sender", None), )
event.type,
)
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.

View File

@ -17,6 +17,7 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from twisted.internet import defer from twisted.internet import defer
@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [ SyncConfig = collections.namedtuple("SyncConfig", [
"user", "user",
"filter", "filter_collection",
"is_guest", "is_guest",
]) ])
@ -129,6 +130,11 @@ class SyncHandler(BaseHandler):
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks
def get_sync_for_user(self, sync_config, since_token=None, timeout=0,
filter_collection=DEFAULT_FILTER_COLLECTION):
pass
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False): full_state=False):
@ -142,8 +148,9 @@ class SyncHandler(BaseHandler):
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token, result = yield self.current_sync_for_user(
full_state=full_state) sync_config, since_token, full_state=full_state,
)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
@ -151,7 +158,7 @@ class SyncHandler(BaseHandler):
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user.to_string(), timeout, current_sync_callback, sync_config.user.to_string(), timeout, current_sync_callback,
from_token=since_token from_token=since_token,
) )
defer.returnValue(result) defer.returnValue(result)
@ -205,7 +212,7 @@ class SyncHandler(BaseHandler):
) )
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (Membership.INVITE, Membership.JOIN)
if sync_config.filter.include_leave: if sync_config.filter_collection.include_leave:
membership_list += (Membership.LEAVE, Membership.BAN) membership_list += (Membership.LEAVE, Membership.BAN)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
@ -266,9 +273,17 @@ class SyncHandler(BaseHandler):
deferreds, consumeErrors=True deferreds, consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -302,14 +317,31 @@ class SyncHandler(BaseHandler):
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
current_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
current_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
defer.returnValue(JoinedSyncResult( defer.returnValue(JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=current_state, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
)) ))
@ -365,7 +397,7 @@ class SyncHandler(BaseHandler):
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -388,7 +420,7 @@ class SyncHandler(BaseHandler):
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -419,13 +451,26 @@ class SyncHandler(BaseHandler):
leave_state = yield self.store.get_state_for_event(leave_event_id) leave_state = yield self.store.get_state_for_event(leave_event_id)
leave_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
leave_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
defer.returnValue(ArchivedSyncResult( defer.returnValue(ArchivedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state, state=leave_state,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -444,7 +489,7 @@ class SyncHandler(BaseHandler):
presence, presence_key = yield presence_source.get_new_events( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter_collection.presence_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -473,7 +518,7 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
room_events, _ = yield self.store.get_room_events_stream( room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -538,6 +583,27 @@ class SyncHandler(BaseHandler):
# the timeline is inherently limited if we've just joined # the timeline is inherently limited if we've just joined
limited = True limited = True
recents = sync_config.filter_collection.filter_room_timeline(recents)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state.values()
)
}
acc_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
acc_data = sync_config.filter_collection.filter_room_account_data(
acc_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=TimelineBatch( timeline=TimelineBatch(
@ -546,10 +612,8 @@ class SyncHandler(BaseHandler):
limited=limited, limited=limited,
), ),
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=acc_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications={}, unread_notifications={},
) )
logger.debug("Result for room %s: %r", room_id, room_sync) logger.debug("Result for room %s: %r", room_id, room_sync)
@ -603,9 +667,17 @@ class SyncHandler(BaseHandler):
for event in invite_events for event in invite_events
] ]
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -621,7 +693,7 @@ class SyncHandler(BaseHandler):
limited = True limited = True
recents = [] recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
@ -634,9 +706,9 @@ class SyncHandler(BaseHandler):
from_token=since_token.room_key if since_token else None, from_token=since_token.room_key if since_token else None,
end_token=end_key, end_token=end_key,
) )
(room_key, _) = keys room_key, _ = keys
end_key = "s" + room_key.split('-')[-1] end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events) loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
@ -684,21 +756,28 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token) if batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token
) )
state = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, since_token=since_token,
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=current_state, current_state=current_state,
) )
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state) just_joined = yield self.check_joined_room(sync_config, state)
if just_joined: if just_joined:
state = yield self.get_state_at(room_id, now_token) state = yield self.get_state_at(room_id, now_token)
# batch.limited = True
notifs = yield self.unread_notifs_for_room_id( notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room room_id, sync_config, all_ephemeral_by_room
@ -711,14 +790,29 @@ class SyncHandler(BaseHandler):
1 for notif in notifs if _action_has_highlight(notif["actions"]) 1 for notif in notifs if _action_has_highlight(notif["actions"])
]) ])
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
) )
@ -765,13 +859,26 @@ class SyncHandler(BaseHandler):
current_state=state_events_at_leave, current_state=state_events_at_leave,
) )
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room(
leave_event.room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=self.account_data_for_room( account_data=account_data,
leave_event.room_id, tags_by_room, account_data_by_room
),
) )
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)

View File

@ -130,7 +130,7 @@ class SyncRestServlet(RestServlet):
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,
filter=filter, filter_collection=filter,
is_guest=requester.is_guest, is_guest=requester.is_guest,
) )
@ -154,23 +154,21 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, filter, time_now, requester.access_token_id sync_result.joined, time_now, requester.access_token_id
) )
invited = self.encode_invited( invited = self.encode_invited(
sync_result.invited, filter, time_now, requester.access_token_id sync_result.invited, time_now, requester.access_token_id
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, filter, time_now, requester.access_token_id sync_result.archived, time_now, requester.access_token_id
) )
response_content = { response_content = {
"account_data": self.encode_account_data( "account_data": {"events": sync_result.account_data},
sync_result.account_data, filter, time_now
),
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, time_now
), ),
"rooms": { "rooms": {
"join": joined, "join": joined,
@ -182,24 +180,20 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content)) defer.returnValue((200, response_content))
def encode_presence(self, events, filter, time_now): def encode_presence(self, events, time_now):
formatted = [] formatted = []
for event in events: for event in events:
event = copy.deepcopy(event) event = copy.deepcopy(event)
event['sender'] = event['content'].pop('user_id') event['sender'] = event['content'].pop('user_id')
formatted.append(event) formatted.append(event)
return {"events": filter.filter_presence(formatted)} return {"events": formatted}
def encode_account_data(self, events, filter, time_now): def encode_joined(self, rooms, time_now, token_id):
return {"events": filter.filter_account_data(events)}
def encode_joined(self, rooms, filter, time_now, token_id):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -211,18 +205,17 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id room, time_now, token_id
) )
return joined return joined
def encode_invited(self, rooms, filter, time_now, token_id): def encode_invited(self, rooms, time_now, token_id):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -237,7 +230,9 @@ class SyncRestServlet(RestServlet):
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
) )
invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
invited_state = list(unsigned.pop("invite_room_state", []))
invited_state.append(invite) invited_state.append(invite)
invited[room.room_id] = { invited[room.room_id] = {
"invite_state": {"events": invited_state} "invite_state": {"events": invited_state}
@ -245,13 +240,12 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id): def encode_archived(self, rooms, time_now, token_id):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -263,17 +257,16 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False room, time_now, token_id, joined=False
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True):
""" """
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -292,19 +285,17 @@ class SyncRestServlet(RestServlet):
) )
state_dict = room.state state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events) timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline( state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events) state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values()) state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
serialized_timeline = [serialize(e) for e in timeline_events] serialized_timeline = [serialize(e) for e in timeline_events]
account_data = filter.filter_room_account_data( account_data = room.account_data
room.account_data
)
result = { result = {
"timeline": { "timeline": {
@ -317,7 +308,7 @@ class SyncRestServlet(RestServlet):
} }
if joined: if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
@ -334,8 +325,6 @@ class SyncRestServlet(RestServlet):
:param list[synapse.events.EventBase] timeline: the event timeline :param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary :return: updated state dictionary
""" """
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy() result = state.copy()
@ -357,8 +346,8 @@ class SyncRestServlet(RestServlet):
# the event graph, and the state is no longer valid. Really, # the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore # the event shouldn't be in the timeline. We're going to ignore
# it for now, however. # it for now, however.
logger.warn("Found state event %r in timeline which doesn't " logger.debug("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event) "match state dictionary", timeline_event)
continue continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None) prev_event_id = timeline_event.unsigned.get("replaces_state", None)