diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 9881f068c..ab928a16d 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -26,11 +26,48 @@ import logging logger = logging.getLogger(__name__) -class PusherRestServlet(ClientV1RestServlet): +class PushersRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/pushers$") + + def __init__(self, hs): + super(PushersRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user = requester.user + + pushers = yield self.hs.get_datastore().get_pushers_by_user_id( + user.to_string() + ) + + allowed_keys = [ + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", + ] + + for p in pushers: + for k, v in p.items(): + if k not in allowed_keys: + del p[k] + + defer.returnValue((200, {"pushers": pushers})) + + def on_OPTIONS(self, _): + return 200, {} + + +class PushersSetRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/pushers/set$") def __init__(self, hs): - super(PusherRestServlet, self).__init__(hs) + super(PushersSetRestServlet, self).__init__(hs) self.notifier = hs.get_notifier() @defer.inlineCallbacks @@ -100,4 +137,5 @@ class PusherRestServlet(ClientV1RestServlet): def register_servlets(hs, http_server): - PusherRestServlet(hs).register(http_server) + PushersRestServlet(hs).register(http_server) + PushersSetRestServlet(hs).register(http_server) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 19888a8e7..e64c0dce0 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -56,24 +56,40 @@ class PusherStore(SQLBaseStore): ) defer.returnValue(ret is not None) - @defer.inlineCallbacks def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - def r(txn): - sql = ( - "SELECT * FROM pushers" - " WHERE app_id = ? AND pushkey = ?" - ) + return self.get_pushers_by({ + "app_id": app_id, + "pushkey": pushkey, + }) - txn.execute(sql, (app_id, pushkey,)) - rows = self.cursor_to_dict(txn) + def get_pushers_by_user_id(self, user_id): + return self.get_pushers_by({ + "user_name": user_id, + }) - return self._decode_pushers_rows(rows) - - rows = yield self.runInteraction( - "get_pushers_by_app_id_and_pushkey", r + @defer.inlineCallbacks + def get_pushers_by(self, keyvalues): + ret = yield self._simple_select_list( + "pushers", keyvalues, + [ + "id", + "user_name", + "access_token", + "profile_tag", + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "ts", + "lang", + "data", + "last_stream_ordering", + "last_success", + "failing_since", + ], desc="get_pushers_by" ) - - defer.returnValue(rows) + defer.returnValue(self._decode_pushers_rows(ret)) @defer.inlineCallbacks def get_all_pushers(self):