Use abstract base class to access stream IDs

This commit is contained in:
Erik Johnston 2018-02-20 17:33:18 +00:00
parent f5ac4dc2d4
commit e316bbb4c0
2 changed files with 34 additions and 17 deletions

View File

@ -31,11 +31,16 @@ from synapse.storage.receipts import ReceiptsWorkerStore
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
receipts_id_gen = SlavedIdTracker( # We instansiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) super(SlavedReceiptsStore, self).__init__(db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def stream_positions(self): def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions() result = super(SlavedReceiptsStore, self).stream_positions()

View File

@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
import abc
import logging import logging
import ujson as json import ujson as json
@ -29,21 +30,30 @@ logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, receipts_id_gen, db_conn, hs): """This is an abstract base class where subclasses must implement
""" `get_max_receipt_stream_id` which can be called in the initializer.
Args: """
receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
db_conn: Database connection # This ABCMeta metaclass ensures that we cannot be instantiated without
hs (Homeserver) # the abstract methods being implemented.
""" __metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(db_conn, hs) super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_id_gen = receipts_id_gen
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
) )
@abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
pass
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id): def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read") receipts = yield self.get_receipts_for_room(room_id, "m.read")
@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
} }
defer.returnValue(results) defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def get_all_updated_receipts(self, last_id, current_id, limit=None): def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return defer.succeed([])
@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore): class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
receipts_id_gen = StreamIdGenerator( # We instansiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) super(ReceiptsStore, self).__init__(db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id): user_id):