Merge pull request #524 from matrix-org/erikj/sync

Move some sync logic from rest to handlers pacakege
This commit is contained in:
Erik Johnston 2016-01-25 16:58:39 +00:00
commit c887c4cbd5
4 changed files with 178 additions and 93 deletions

View File

@ -190,18 +190,16 @@ class Filter(object):
Returns: Returns:
bool: True if the event matches bool: True if the event matches
""" """
if isinstance(event, dict): sender = event.get("sender", None)
if not sender:
# Presence events have their 'sender' in content.user_id
sender = event.get("content", {}).get("user_id", None)
return self.check_fields( return self.check_fields(
event.get("room_id", None), event.get("room_id", None),
event.get("sender", None), sender,
event.get("type", None), event.get("type", None),
) )
else:
return self.check_fields(
getattr(event, "room_id", None),
getattr(event, "sender", None),
event.type,
)
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [ SyncConfig = collections.namedtuple("SyncConfig", [
"user", "user",
"filter", "filter_collection",
"is_guest", "is_guest",
]) ])
@ -142,8 +142,9 @@ class SyncHandler(BaseHandler):
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token, result = yield self.current_sync_for_user(
full_state=full_state) sync_config, since_token, full_state=full_state,
)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
@ -151,7 +152,7 @@ class SyncHandler(BaseHandler):
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user.to_string(), timeout, current_sync_callback, sync_config.user.to_string(), timeout, current_sync_callback,
from_token=since_token from_token=since_token,
) )
defer.returnValue(result) defer.returnValue(result)
@ -205,7 +206,7 @@ class SyncHandler(BaseHandler):
) )
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (Membership.INVITE, Membership.JOIN)
if sync_config.filter.include_leave: if sync_config.filter_collection.include_leave:
membership_list += (Membership.LEAVE, Membership.BAN) membership_list += (Membership.LEAVE, Membership.BAN)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
@ -266,9 +267,17 @@ class SyncHandler(BaseHandler):
deferreds, consumeErrors=True deferreds, consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -302,14 +311,31 @@ class SyncHandler(BaseHandler):
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
current_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
current_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
defer.returnValue(JoinedSyncResult( defer.returnValue(JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=current_state, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
)) ))
@ -365,7 +391,7 @@ class SyncHandler(BaseHandler):
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -388,7 +414,7 @@ class SyncHandler(BaseHandler):
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -419,13 +445,26 @@ class SyncHandler(BaseHandler):
leave_state = yield self.store.get_state_for_event(leave_event_id) leave_state = yield self.store.get_state_for_event(leave_event_id)
leave_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
leave_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
defer.returnValue(ArchivedSyncResult( defer.returnValue(ArchivedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state, state=leave_state,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -444,7 +483,7 @@ class SyncHandler(BaseHandler):
presence, presence_key = yield presence_source.get_new_events( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter_collection.presence_limit(),
room_ids=room_ids, room_ids=room_ids,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
) )
@ -473,7 +512,7 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
room_events, _ = yield self.store.get_room_events_stream( room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -538,6 +577,27 @@ class SyncHandler(BaseHandler):
# the timeline is inherently limited if we've just joined # the timeline is inherently limited if we've just joined
limited = True limited = True
recents = sync_config.filter_collection.filter_room_timeline(recents)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state.values()
)
}
acc_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
acc_data = sync_config.filter_collection.filter_room_account_data(
acc_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=TimelineBatch( timeline=TimelineBatch(
@ -546,10 +606,8 @@ class SyncHandler(BaseHandler):
limited=limited, limited=limited,
), ),
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=acc_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications={}, unread_notifications={},
) )
logger.debug("Result for room %s: %r", room_id, room_sync) logger.debug("Result for room %s: %r", room_id, room_sync)
@ -603,9 +661,17 @@ class SyncHandler(BaseHandler):
for event in invite_events for event in invite_events
] ]
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
)
presence = sync_config.filter_collection.filter_presence(
presence
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data), account_data=account_data_for_user,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -621,7 +687,7 @@ class SyncHandler(BaseHandler):
limited = True limited = True
recents = [] recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
@ -634,9 +700,9 @@ class SyncHandler(BaseHandler):
from_token=since_token.room_key if since_token else None, from_token=since_token.room_key if since_token else None,
end_token=end_key, end_token=end_key,
) )
(room_key, _) = keys room_key, _ = keys
end_key = "s" + room_key.split('-')[-1] end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events) loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
@ -684,6 +750,7 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
if batch.limited:
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
@ -695,6 +762,11 @@ class SyncHandler(BaseHandler):
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=current_state, current_state=current_state,
) )
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state) just_joined = yield self.check_joined_room(sync_config, state)
if just_joined: if just_joined:
@ -711,14 +783,29 @@ class SyncHandler(BaseHandler):
1 for notif in notifs if _action_has_highlight(notif["actions"]) 1 for notif in notifs if _action_has_highlight(notif["actions"])
]) ])
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral,
account_data=self.account_data_for_room( account_data=account_data,
room_id, tags_by_room, account_data_by_room
),
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
) )
@ -765,13 +852,26 @@ class SyncHandler(BaseHandler):
current_state=state_events_at_leave, current_state=state_events_at_leave,
) )
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room(
leave_event.room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=self.account_data_for_room( account_data=account_data,
leave_event.room_id, tags_by_room, account_data_by_room
),
) )
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)

View File

@ -130,7 +130,7 @@ class SyncRestServlet(RestServlet):
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,
filter=filter, filter_collection=filter,
is_guest=requester.is_guest, is_guest=requester.is_guest,
) )
@ -154,23 +154,21 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, filter, time_now, requester.access_token_id sync_result.joined, time_now, requester.access_token_id
) )
invited = self.encode_invited( invited = self.encode_invited(
sync_result.invited, filter, time_now, requester.access_token_id sync_result.invited, time_now, requester.access_token_id
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, filter, time_now, requester.access_token_id sync_result.archived, time_now, requester.access_token_id
) )
response_content = { response_content = {
"account_data": self.encode_account_data( "account_data": {"events": sync_result.account_data},
sync_result.account_data, filter, time_now
),
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, time_now
), ),
"rooms": { "rooms": {
"join": joined, "join": joined,
@ -182,24 +180,20 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content)) defer.returnValue((200, response_content))
def encode_presence(self, events, filter, time_now): def encode_presence(self, events, time_now):
formatted = [] formatted = []
for event in events: for event in events:
event = copy.deepcopy(event) event = copy.deepcopy(event)
event['sender'] = event['content'].pop('user_id') event['sender'] = event['content'].pop('user_id')
formatted.append(event) formatted.append(event)
return {"events": filter.filter_presence(formatted)} return {"events": formatted}
def encode_account_data(self, events, filter, time_now): def encode_joined(self, rooms, time_now, token_id):
return {"events": filter.filter_account_data(events)}
def encode_joined(self, rooms, filter, time_now, token_id):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -211,18 +205,17 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id room, time_now, token_id
) )
return joined return joined
def encode_invited(self, rooms, filter, time_now, token_id): def encode_invited(self, rooms, time_now, token_id):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -237,7 +230,9 @@ class SyncRestServlet(RestServlet):
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
) )
invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
invited_state = list(unsigned.pop("invite_room_state", []))
invited_state.append(invite) invited_state.append(invite)
invited[room.room_id] = { invited[room.room_id] = {
"invite_state": {"events": invited_state} "invite_state": {"events": invited_state}
@ -245,13 +240,12 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id): def encode_archived(self, rooms, time_now, token_id):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -263,17 +257,16 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False room, time_now, token_id, joined=False
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True):
""" """
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age :param int time_now: current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing :param int token_id: ID of the user's auth token - used for namespacing
@ -292,19 +285,17 @@ class SyncRestServlet(RestServlet):
) )
state_dict = room.state state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events) timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline( state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events) state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values()) state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
serialized_timeline = [serialize(e) for e in timeline_events] serialized_timeline = [serialize(e) for e in timeline_events]
account_data = filter.filter_room_account_data( account_data = room.account_data
room.account_data
)
result = { result = {
"timeline": { "timeline": {
@ -317,7 +308,7 @@ class SyncRestServlet(RestServlet):
} }
if joined: if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
@ -334,8 +325,6 @@ class SyncRestServlet(RestServlet):
:param list[synapse.events.EventBase] timeline: the event timeline :param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary :return: updated state dictionary
""" """
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy() result = state.copy()
@ -357,7 +346,7 @@ class SyncRestServlet(RestServlet):
# the event graph, and the state is no longer valid. Really, # the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore # the event shouldn't be in the timeline. We're going to ignore
# it for now, however. # it for now, however.
logger.warn("Found state event %r in timeline which doesn't " logger.debug("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event) "match state dictionary", timeline_event)
continue continue

View File

@ -13,26 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock, NonCallableMock from mock import Mock
from tests.utils import ( from tests.utils import (
MockHttpResource, DeferredMockCallable, setup_test_homeserver MockHttpResource, DeferredMockCallable, setup_test_homeserver
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.api.filtering import FilterCollection, Filter from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
user_localpart = "test_user" user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id") # MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs): def MockEvent(**kwargs):
ev = NonCallableMock(spec_set=kwargs.keys()) return FrozenEvent(kwargs)
ev.configure_mock(**kwargs)
return ev
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.TestCase):