From f28cc4518368955f692a1262a688e907cb22d226 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 16:01:40 +0000 Subject: [PATCH] Pass in current state to push action handler --- synapse/handlers/_base.py | 31 ++++++++++-------------- synapse/push/action_generator.py | 6 +++-- synapse/push/bulk_push_rule_evaluator.py | 10 +++----- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index fa83d3e46..d3f722b22 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -53,25 +53,10 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_clients(self, user_tuples, events): + def _filter_events_for_clients(self, user_tuples, events, event_id_to_state): """ Returns dict of user_id -> list of events that user is allowed to see. """ - # If there is only one user, just get the state for that one user, - # otherwise just get all the state. - if len(user_tuples) == 1: - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_tuples[0][0]), - ) - else: - types = None - - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=types - ) - forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( room_id, @@ -135,7 +120,17 @@ class BaseHandler(object): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, events, is_peeking=False): # Assumes that user has at some point joined the room if not is_guest. - res = yield self._filter_events_for_clients([(user_id, is_peeking)], events) + types = ( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=types + ) + res = yield self._filter_events_for_clients( + [(user_id, is_peeking)], events, event_id_to_state + ) defer.returnValue(res.get(user_id, [])) def ratelimit(self, user_id): @@ -275,7 +270,7 @@ class BaseHandler(object): action_generator = ActionGenerator(self.hs) yield action_generator.handle_push_actions_for_event( - event, self + event, self, context.current_state ) destinations = set() diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1d2e558f9..d8f8256a1 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -36,7 +36,7 @@ class ActionGenerator: # tag (ie. we just need all the users). @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, handler): + def handle_push_actions_for_event(self, event, handler, current_state): if event.type == EventTypes.Redaction and event.redacts is not None: yield self.store.remove_push_actions_for_event_id( event.room_id, event.redacts @@ -46,7 +46,9 @@ class ActionGenerator: event.room_id, self.hs, self.store ) - actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler) + actions_by_user = yield bulk_evaluator.action_for_event_by_user( + event, handler, current_state + ) yield self.store.set_push_actions_for_event_and_users( event, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 20c60422b..8ac5ceb9e 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -98,25 +98,21 @@ class BulkPushRuleEvaluator: self.store = store @defer.inlineCallbacks - def action_for_event_by_user(self, event, handler): + def action_for_event_by_user(self, event, handler, current_state): actions_by_user = {} users_dict = yield self.store.are_guests(self.rules_by_user.keys()) filtered_by_user = yield handler._filter_events_for_clients( - users_dict.items(), [event] + users_dict.items(), [event], {event.event_id: current_state} ) evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room)) condition_cache = {} - member_state = yield self.store.get_state_for_event( - event.event_id, - ) - display_names = {} - for ev in member_state.values(): + for ev in current_state.values(): nm = ev.content.get("displayname", None) if nm and ev.type == EventTypes.Member: display_names[ev.state_key] = nm