Split ReceiptsStore

This commit is contained in:
Erik Johnston 2018-02-16 12:17:14 +00:00
parent 324c3e9399
commit f5ac4dc2d4
3 changed files with 69 additions and 76 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,9 +17,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from # So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the # a DataStore, but we don't want to take functions that either write to the
@ -29,36 +28,14 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs) receipts_id_gen = SlavedIdTracker(
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
self._receipts_stream_cache = StreamChangeCache( super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self): def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions() result = super(SlavedReceiptsStore, self).stream_positions()

View File

@ -104,9 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1, db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")] extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
) )
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id" db_conn, "account_data_max_stream_id", "stream_id"
) )

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,6 +15,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -26,9 +28,17 @@ import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore): class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, receipts_id_gen, db_conn, hs):
super(ReceiptsStore, self).__init__(db_conn, hs) """
Args:
receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
db_conn: Database connection
hs (Homeserver)
"""
super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_id_gen = receipts_id_gen
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
@ -39,26 +49,6 @@ class ReceiptsStore(SQLBaseStore):
receipts = yield self.get_receipts_for_room(room_id, "m.read") receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts)) defer.returnValue(set(r['user_id'] for r in receipts))
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
return
# Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
if res:
if isinstance(res, defer.Deferred) and res.called:
res = res.result
if user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
@cached(num_args=2) @cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type): def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list( return self._simple_select_list(
@ -273,6 +263,57 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
" FROM receipts_linearized"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
args = [last_id, current_id]
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, db_conn, hs):
receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
return
# Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
if res:
if isinstance(res, defer.Deferred) and res.called:
res = res.result
if user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id): user_id, event_id, data, stream_id):
txn.call_after( txn.call_after(
@ -457,25 +498,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data), "data": json.dumps(data),
} }
) )
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
" FROM receipts_linearized"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
args = [last_id, current_id]
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)