# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import ujson

from twisted.internet import defer

from ._base import SQLBaseStore


logger = logging.getLogger(__name__)


class DeviceInboxStore(SQLBaseStore):

    @defer.inlineCallbacks
    def add_messages_to_device_inbox(self, messages_by_user_then_device):
        """Insert to-device messages into the inbox of each local device.

        Only (user_id, device_id) pairs that have a row in the ``devices``
        table get a message inserted; messages addressed to any other device
        are silently dropped.

        Args:
            messages_by_user_then_device(dict): Dictionary of user_id to
                device_id to message. Each message is serialised with
                ``ujson.dumps`` before being stored.
        Returns:
            A deferred that resolves, once the messages have been inserted,
            to ``self._device_inbox_id_gen.get_current_token()``.
        """

        def select_devices_txn(txn, user_id, devices):
            # Return the (user_id, device_id) tuples for those of ``devices``
            # that exist in the devices table for ``user_id``.
            # ``devices`` may be a dict keys view (not a list); materialise it
            # so it can be concatenated onto the SQL argument list below.
            devices = list(devices)
            if not devices:
                return []
            sql = (
                "SELECT user_id, device_id FROM devices"
                " WHERE user_id = ? AND device_id IN ("
                + ",".join("?" * len(devices))
                + ")"
            )
            # TODO: Maybe this needs to be done in batches if there are
            # too many local devices for a given user.
            args = [user_id] + devices
            txn.execute(sql, args)
            return [tuple(row) for row in txn.fetchall()]

        def add_messages_to_device_inbox_txn(txn, stream_id):
            # First work out which of the addressed devices exist locally.
            local_users_and_devices = set()
            for user_id, messages_by_device in messages_by_user_then_device.items():
                local_users_and_devices.update(
                    select_devices_txn(txn, user_id, messages_by_device.keys())
                )

            sql = (
                "INSERT INTO device_inbox"
                " (user_id, device_id, stream_id, message_json)"
                " VALUES (?,?,?,?)"
            )
            rows = []
            for user_id, messages_by_device in messages_by_user_then_device.items():
                for device_id, message in messages_by_device.items():
                    # Only insert into the local inbox if the device exists on
                    # this server; serialise only the messages we keep.
                    if (user_id, device_id) in local_users_and_devices:
                        rows.append(
                            (user_id, device_id, stream_id, ujson.dumps(message))
                        )

            txn.executemany(sql, rows)

        # All messages in this call share a single stream_id.
        with self._device_inbox_id_gen.get_next() as stream_id:
            yield self.runInteraction(
                "add_messages_to_device_inbox",
                add_messages_to_device_inbox_txn,
                stream_id
            )

        defer.returnValue(self._device_inbox_id_gen.get_current_token())

    def get_new_messages_for_device(
        self, user_id, device_id, current_stream_id, limit=100
    ):
        """Fetch pending inbox messages for a single device.

        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            current_stream_id(int): The current position of the to device
                message stream; only messages with stream_id <= this are
                returned.
            limit(int): The maximum number of messages to return.
        Returns:
            Deferred ([dict], int): List of messages for the device (in
            ascending stream order) and where in the stream the messages got
            to. If fewer than ``limit`` messages were found the position is
            ``current_stream_id``; otherwise it is the stream_id of the last
            message returned.
        """
        def get_new_messages_for_device_txn(txn):
            sql = (
                "SELECT stream_id, message_json FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (user_id, device_id, current_stream_id, limit))
            messages = []
            for row in txn.fetchall():
                stream_pos = row[0]
                messages.append(ujson.loads(row[1]))
            if len(messages) < limit:
                # The query was not truncated by the limit, so everything up
                # to current_stream_id has been returned.
                stream_pos = current_stream_id
            return (messages, stream_pos)

        return self.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn,
        )

    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
        """Delete inbox messages for a device up to a stream position.

        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            up_to_stream_id(int): Where to delete messages up to (inclusive).
        Returns:
            A deferred that resolves when the messages have been deleted.
        """
        def delete_messages_for_device_txn(txn):
            sql = (
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (user_id, device_id, up_to_stream_id))

        return self.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

    def get_all_new_device_messages(self, last_pos, current_pos, limit):
        """Fetch all device inbox rows in a range of stream positions.

        Args:
            last_pos(int): Exclusive lower bound on stream_id.
            current_pos(int): Inclusive upper bound on stream_id.
            limit(int): The maximum number of distinct stream_ids to cover.
                Every row for each of those stream_ids is returned, so the
                number of rows may exceed ``limit``.
        Returns:
            A deferred list of (stream_id, user_id, device_id, message_json)
            rows from the device inbox, in ascending stream_id order.
        """
        if last_pos == current_pos:
            return defer.succeed([])

        def get_all_new_device_messages_txn(txn):
            # Find the highest stream_id within the first ``limit`` distinct
            # stream_ids, so we never split the rows of one stream_id.
            sql = (
                "SELECT stream_id FROM device_inbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY stream_id"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (last_pos, current_pos, limit))
            stream_ids = txn.fetchall()
            if not stream_ids:
                return []
            # fetchall() returns row tuples; take the stream_id column of the
            # last row so a scalar (not a tuple) is bound as the parameter.
            max_stream_id_in_limit = stream_ids[-1][0]

            sql = (
                "SELECT stream_id, user_id, device_id, message_json"
                " FROM device_inbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
            )
            txn.execute(sql, (last_pos, max_stream_id_in_limit))
            return txn.fetchall()

        return self.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn
        )

    def get_to_device_stream_token(self):
        """Return the current token of the device inbox ID generator."""
        return self._device_inbox_id_gen.get_current_token()