diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b7567b9ea..053ed8480 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -37,8 +37,7 @@ class ReceiptsHandler(BaseHandler): "m.receipt", self._received_remote_receipt ) - # self._earliest_cached_serial = 0 - # self._rooms_to_latest_serial = {} + self._receipt_cache = None @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -162,17 +161,9 @@ class ReceiptEventSource(object): rooms = yield self.store.get_rooms_for_user(user.to_string()) rooms = [room.room_id for room in rooms] - events = [] - for room_id in rooms: - content = yield self.store.get_linearized_receipts_for_room( - room_id, from_key, to_key - ) - if content: - events.append({ - "type": "m.receipt", - "room_id": room_id, - "content": content, - }) + events = yield self.store.get_linearized_receipts_for_rooms( + rooms, from_key, to_key + ) defer.returnValue((events, to_key)) @@ -190,16 +181,8 @@ class ReceiptEventSource(object): rooms = yield self.store.get_rooms_for_user(user.to_string()) rooms = [room.room_id for room in rooms] - events = [] - for room_id in rooms: - content = yield self.store.get_linearized_receipts_for_room( - room_id, from_key, to_key - ) - if content: - events.append({ - "type": "m.receipt", - "room_id": room_id, - "content": content, - }) + events = yield self.store.get_linearized_receipts_for_rooms( + rooms, from_key, to_key + ) defer.returnValue((events, to_key)) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 07f8edaac..503f68f85 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -17,6 +17,9 @@ from ._base import SQLBaseStore, cached from twisted.internet import defer +from synapse.util import unwrapFirstError + +from blist import sorteddict import logging @@ -24,6 +27,29 @@ logger = logging.getLogger(__name__) class ReceiptsStore(SQLBaseStore): + def __init__(self, hs): + super(ReceiptsStore, self).__init__(hs) + + self._receipts_stream_cache = _RoomStreamChangeCache() + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, from_key, to_key): + room_ids = set(room_ids) + + if from_key: + room_ids = yield self._receipts_stream_cache.get_rooms_changed( + self, room_ids, from_key + ) + + results = yield defer.gatherResults( + [ + self.get_linearized_receipts_for_room(room_id, from_key, to_key) + for room_id in room_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue([ev for res in results for ev in res]) @defer.inlineCallbacks def get_linearized_receipts_for_room(self, room_id, from_key, to_key): @@ -57,15 +83,22 @@ class ReceiptsStore(SQLBaseStore): "get_linearized_receipts_for_room", f ) - result = {} + if not rows: + defer.returnValue([]) + + content = {} for row in rows: - result.setdefault( + content.setdefault( row["event_id"], {} ).setdefault( row["receipt_type"], [] ).append(row["user_id"]) - defer.returnValue(result) + defer.returnValue([{ + "type": "m.receipt", + "room_id": room_id, + "content": content, + }]) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_max_token(self) @@ -174,6 +207,9 @@ class ReceiptsStore(SQLBaseStore): stream_id_manager = yield self._receipts_id_gen.get_next(self) with stream_id_manager as stream_id: + yield self._receipts_stream_cache.room_has_changed( + self, room_id, stream_id + ) have_persisted = yield self.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, @@ -223,3 +259,53 @@ class ReceiptsStore(SQLBaseStore): for event_id in event_ids ], ) + + +class _RoomStreamChangeCache(object): + """Keeps track of the stream_id of the latest change in rooms. + + Given a list of rooms and stream key, it will give a subset of rooms that + may have changed since that key. If the key is too old then the cache + will simply return all rooms. + """ + def __init__(self, size_of_cache=1000): + self._size_of_cache = size_of_cache + self._room_to_key = {} + self._cache = sorteddict() + self._earliest_key = None + + @defer.inlineCallbacks + def get_rooms_changed(self, store, room_ids, key): + if key > (yield self._get_earliest_key(store)): + keys = self._cache.keys() + i = keys.bisect_right(key) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(room_ids) + else: + result = room_ids + + defer.returnValue(result) + + @defer.inlineCallbacks + def room_has_changed(self, store, room_id, key): + if key > (yield self._get_earliest_key(store)): + old_key = self._room_to_key.get(room_id, None) + if old_key: + key = max(key, old_key) + self._cache.pop(old_key, None) + self._cache[key] = room_id + + while len(self._cache) > self._size_of_cache: + k, r = self._cache.popitem() + self._earliest_key = max(k, self._earliest_key) + self._room_to_key.pop(r, None) + + @defer.inlineCallbacks + def _get_earliest_key(self, store): + if self._earliest_key is None: + self._earliest_key = yield store.get_max_receipt_stream_id() + self._earliest_key = int(self._earliest_key) + + defer.returnValue(self._earliest_key)