Allow moving account data and receipts streams off master (#9104)

This commit is contained in:
Erik Johnston 2021-01-18 15:47:59 +00:00 committed by GitHub
parent f08ef64926
commit 6633a4015a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 854 additions and 279 deletions

View file

@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
)
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
sequence_name="receipts_sequence",
writers=hs.config.worker.writers.receipts,
)
else:
self._can_write_to_receipts = True
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
else:
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
@abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
raise NotImplementedError()
return self._receipts_id_gen.get_current_token()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
super().__init__(database, db_conn, hs)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
assert self._can_write_to_receipts
res = self.db_pool.simple_select_one_txn(
txn,
table="events",
@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
return None
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id,
receipt_type,
user_id,
)
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type),
)
self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph
representations.
"""
assert self._can_write_to_receipts
if not event_ids:
return None
@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
assert self._can_write_to_receipts
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data),
},
)
class ReceiptsStore(ReceiptsWorkerStore):
pass