Amalgamate _filter_events_for_client

This commit is contained in:
Erik Johnston 2015-10-16 14:52:48 +01:00
parent 5df54de801
commit 366af6b73a
3 changed files with 51 additions and 97 deletions

View File

@ -45,6 +45,52 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events):
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(

View File

@ -146,7 +146,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, room_id, events) events = yield self._filter_events_for_client(user_id, events)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -161,52 +161,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None): token_id=None, txn_id=None):
@ -424,7 +378,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, event.room_id, messages user_id, messages
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -519,7 +473,7 @@ class MessageHandler(BaseHandler):
) )
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, room_id, messages user_id, messages
) )
start_token = StreamToken(token[0], 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0)
@ -599,7 +553,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, room_id, messages user_id, messages
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])

View File

@ -321,52 +321,6 @@ class SyncHandler(BaseHandler):
next_batch=now_token, next_batch=now_token,
)) ))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None):
@ -390,7 +344,7 @@ class SyncHandler(BaseHandler):
end_key = "s" + room_key.split('-')[-1] end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events) loaded_recents = sync_config.filter.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), room_id, loaded_recents, sync_config.user.to_string(), loaded_recents,
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents