mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Split ReceiptsStore
This commit is contained in:
parent
324c3e9399
commit
f5ac4dc2d4
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -16,9 +17,7 @@
|
|||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.receipts import ReceiptsWorkerStore
|
||||||
from synapse.storage.receipts import ReceiptsStore
|
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
|
|
||||||
# So, um, we want to borrow a load of functions intended for reading from
|
# So, um, we want to borrow a load of functions intended for reading from
|
||||||
# a DataStore, but we don't want to take functions that either write to the
|
# a DataStore, but we don't want to take functions that either write to the
|
||||||
@ -29,36 +28,14 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|||||||
# the method descriptor on the DataStore and chuck them into our class.
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
class SlavedReceiptsStore(BaseSlavedStore):
|
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
|
receipts_id_gen = SlavedIdTracker(
|
||||||
|
|
||||||
self._receipts_id_gen = SlavedIdTracker(
|
|
||||||
db_conn, "receipts_linearized", "stream_id"
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._receipts_stream_cache = StreamChangeCache(
|
super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
|
||||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
|
||||||
)
|
|
||||||
|
|
||||||
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
|
|
||||||
get_linearized_receipts_for_room = (
|
|
||||||
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
|
|
||||||
)
|
|
||||||
_get_linearized_receipts_for_rooms = (
|
|
||||||
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
|
|
||||||
)
|
|
||||||
get_last_receipt_event_id_for_user = (
|
|
||||||
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
|
|
||||||
)
|
|
||||||
|
|
||||||
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
|
|
||||||
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
|
|
||||||
|
|
||||||
get_linearized_receipts_for_rooms = (
|
|
||||||
DataStore.get_linearized_receipts_for_rooms.__func__
|
|
||||||
)
|
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||||
|
@ -104,9 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
db_conn, "events", "stream_ordering", step=-1,
|
db_conn, "events", "stream_ordering", step=-1,
|
||||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||||
)
|
)
|
||||||
self._receipts_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, "receipts_linearized", "stream_id"
|
|
||||||
)
|
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
self._account_data_id_gen = StreamIdGenerator(
|
||||||
db_conn, "account_data_max_stream_id", "stream_id"
|
db_conn, "account_data_max_stream_id", "stream_id"
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -14,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from .util.id_generators import StreamIdGenerator
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
@ -26,9 +28,17 @@ import ujson as json
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReceiptsStore(SQLBaseStore):
|
class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, receipts_id_gen, db_conn, hs):
|
||||||
super(ReceiptsStore, self).__init__(db_conn, hs)
|
"""
|
||||||
|
Args:
|
||||||
|
receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
|
||||||
|
db_conn: Database connection
|
||||||
|
hs (Homeserver)
|
||||||
|
"""
|
||||||
|
super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self._receipts_id_gen = receipts_id_gen
|
||||||
|
|
||||||
self._receipts_stream_cache = StreamChangeCache(
|
self._receipts_stream_cache = StreamChangeCache(
|
||||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||||
@ -39,26 +49,6 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
||||||
defer.returnValue(set(r['user_id'] for r in receipts))
|
defer.returnValue(set(r['user_id'] for r in receipts))
|
||||||
|
|
||||||
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
|
||||||
user_id):
|
|
||||||
if receipt_type != "m.read":
|
|
||||||
return
|
|
||||||
|
|
||||||
# Returns an ObservableDeferred
|
|
||||||
res = self.get_users_with_read_receipts_in_room.cache.get(
|
|
||||||
room_id, None, update_metrics=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if res:
|
|
||||||
if isinstance(res, defer.Deferred) and res.called:
|
|
||||||
res = res.result
|
|
||||||
if user_id in res:
|
|
||||||
# We'd only be adding to the set, so no point invalidating if the
|
|
||||||
# user is already there
|
|
||||||
return
|
|
||||||
|
|
||||||
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
def get_receipts_for_room(self, room_id, receipt_type):
|
||||||
return self._simple_select_list(
|
return self._simple_select_list(
|
||||||
@ -273,6 +263,57 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
def get_max_receipt_stream_id(self):
|
def get_max_receipt_stream_id(self):
|
||||||
return self._receipts_id_gen.get_current_token()
|
return self._receipts_id_gen.get_current_token()
|
||||||
|
|
||||||
|
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
||||||
|
if last_id == current_id:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
|
def get_all_updated_receipts_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
||||||
|
" FROM receipts_linearized"
|
||||||
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
)
|
||||||
|
args = [last_id, current_id]
|
||||||
|
if limit is not None:
|
||||||
|
sql += " LIMIT ?"
|
||||||
|
args.append(limit)
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
return txn.fetchall()
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_all_updated_receipts", get_all_updated_receipts_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptsStore(ReceiptsWorkerStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
receipts_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
|
||||||
|
|
||||||
|
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
||||||
|
user_id):
|
||||||
|
if receipt_type != "m.read":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Returns an ObservableDeferred
|
||||||
|
res = self.get_users_with_read_receipts_in_room.cache.get(
|
||||||
|
room_id, None, update_metrics=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
if isinstance(res, defer.Deferred) and res.called:
|
||||||
|
res = res.result
|
||||||
|
if user_id in res:
|
||||||
|
# We'd only be adding to the set, so no point invalidating if the
|
||||||
|
# user is already there
|
||||||
|
return
|
||||||
|
|
||||||
|
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
||||||
|
|
||||||
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||||
user_id, event_id, data, stream_id):
|
user_id, event_id, data, stream_id):
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
@ -457,25 +498,3 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
"data": json.dumps(data),
|
"data": json.dumps(data),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
|
||||||
if last_id == current_id:
|
|
||||||
return defer.succeed([])
|
|
||||||
|
|
||||||
def get_all_updated_receipts_txn(txn):
|
|
||||||
sql = (
|
|
||||||
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
|
||||||
" FROM receipts_linearized"
|
|
||||||
" WHERE ? < stream_id AND stream_id <= ?"
|
|
||||||
" ORDER BY stream_id ASC"
|
|
||||||
)
|
|
||||||
args = [last_id, current_id]
|
|
||||||
if limit is not None:
|
|
||||||
sql += " LIMIT ?"
|
|
||||||
args.append(limit)
|
|
||||||
txn.execute(sql, args)
|
|
||||||
|
|
||||||
return txn.fetchall()
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_all_updated_receipts", get_all_updated_receipts_txn
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user