Allow moving account data and receipts streams off master (#9104)

This commit is contained in:
Erik Johnston 2021-01-18 15:47:59 +00:00 committed by GitHub
parent f08ef64926
commit 6633a4015a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 854 additions and 279 deletions

1
changelog.d/9104.feature Normal file
View File

@ -0,0 +1 @@
Add experimental support for moving off receipts and account data persistence off master.

View File

@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
from synapse.rest.client.v2_alpha import (
account_data,
groups,
read_marker,
receipts,
room_keys,
sync,
tags,
user_directory,
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
room_keys.register_servlets(self, resource)
tags.register_servlets(self, resource)
account_data.register_servlets(self, resource)
receipts.register_servlets(self, resource)
read_marker.register_servlets(self, resource)
SendToDeviceRestServlet(self).register(resource)

View File

@ -56,6 +56,12 @@ class WriterLocations:
to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
account_data = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
receipts = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
class WorkerConfig(Config):
@ -127,7 +133,7 @@ class WorkerConfig(Config):
# Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing", "to_device"):
for stream in ("events", "typing", "to_device", "account_data", "receipts"):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
@ -141,6 +147,16 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `to_device` messages."
)
if len(self.writers.account_data) != 1:
raise ConfigError(
"Must only specify one instance to handle `account_data` messages."
)
if len(self.writers.receipts) != 1:
raise ConfigError(
"Must only specify one instance to handle `receipts` messages."
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
# Whether this worker should run background tasks or not.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
ReplicationRemoveTagRestServlet,
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_to_room(
user_id, room_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._room_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_for_user(
user_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._user_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
tag: The tag name to add.
content: A json object to associate with the tag.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_tag_to_room(
user_id, room_id, tag, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._add_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
content=content,
)
return response["max_stream_id"]
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.remove_tag_from_room(
user_id, room_id, tag
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._remove_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
)
return response["max_stream_id"]
class AccountDataEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

View File

@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update:
content = {"event_id": event_id}
max_id = await self.store.add_account_data_to_room(
await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View File

@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
# `synapse.app.generic_worker.FederationSenderHandler` when it sees it
# in the receipts stream.
self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
# If we can handle the receipt EDUs we do so, otherwise we route them
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.receipt", hs.config.worker.writers.receipts,
)
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
if not is_new:
return
await self.federation.send_read_receipt(receipt)
if self.federation_sender:
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:

View File

@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
await self.store.add_account_data_for_user(
await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
await self.account_data_handler.add_tag_to_room(
user_id, new_room_id, tag, tag_content
)
async def update_membership(
self,

View File

@ -15,6 +15,7 @@
from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
devices,
federation,
login,
@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
streams.register_servlets(hs, self)
account_data.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:

View File

@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
"""Add user account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_user_account_data/:user_id/:type
{
"content": { ... },
}
"""
NAME = "add_user_account_data"
PATH_ARGS = ("user_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
"""Add room account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
{
"content": { ... },
}
"""
NAME = "add_room_account_data"
PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationAddTagRestServlet(ReplicationEndpoint):
"""Add tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
{
"content": { ... },
}
"""
NAME = "add_tag"
PATH_ARGS = ("user_id", "room_id", "tag")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, tag):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
"""Remove tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
{}
"""
NAME = "remove_tag"
PATH_ARGS = (
"user_id",
"room_id",
"tag",
)
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag):
return {}
async def _handle_request(self, request, user_id, room_id, tag):
max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs, http_server):
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
ReplicationRemoveTagRestServlet(hs).register(http_server)

View File

@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
tables=[
(
"cache_invalidation_stream_by_instance",
"instance_name",
"stream_id",
)
],
sequence_name="cache_invalidation_stream_seq",
writers=[],
) # type: Optional[MultiWriterIdGenerator]

View File

@ -15,47 +15,9 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
)
super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass

View File

@ -14,43 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
pass

View File

@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@ -132,6 +135,22 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
# Only add AccountDataStream and TagAccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, ReceiptsStream):
# Only add ReceiptsStream as a source on the instance in charge of
# receipts.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue

View File

@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
max_id = await self.store.add_account_data_to_room(
await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):

View File

@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}

View File

@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.admin import AdminHandler
@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler())
@cache_in_self
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -160,9 +160,13 @@ class DataStore(
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
tables=[
(
"cache_invalidation_stream_by_instance",
"instance_name",
"stream_id",
)
],
sequence_name="cache_invalidation_stream_seq",
writers=[],
)

View File

@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[
("room_account_data", "instance_name", "stream_id"),
("room_tags_revisions", "instance_name", "stream_id"),
("account_data", "instance_name", "stream_id"),
],
sequence_name="account_data_sequence",
writers=hs.config.worker.writers.account_data,
)
else:
self._can_write_to_account_data = True
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
else:
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
Returns:
int
"""
raise NotImplementedError()
return self._account_data_id_gen.get_current_token()
@cached()
async def get_account_data_for_user(
@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
)
)
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
Returns:
The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
assert self._can_write_to_account_data
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
assert self._can_write_to_account_data
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"add_user_account_data",
@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
class AccountDataStore(AccountDataWorkerStore):
pass

View File

@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
db=database,
stream_name="to_device",
instance_name=self._instance_name,
table="device_inbox",
instance_column="instance_name",
id_column="stream_id",
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)

View File

@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(rotate_to_stream_ordering,),
)
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
"""
Purges old push actions for a user and room before a given
stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
Args:
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id),
)
# We need to join on the events table to get the received_ts for
# event_push_actions and sqlite won't let us use a join in a delete so
# we can't just delete where received_ts < x. Furthermore we can
# only identify event_push_actions by a tuple of room_id, event_id
# we we can't use a subquery.
# Instead, we look up the stream ordering for the last event in that
# room received before the threshold time and delete event_push_actions
# in the room with a stream_odering before that.
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
txn.execute(
"""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""",
(room_id, user_id, stream_ordering),
)
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
"""
Purges old push actions for a user and room before a given
stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
Args:
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id),
)
# We need to join on the events table to get the received_ts for
# event_push_actions and sqlite won't let us use a join in a delete so
# we can't just delete where received_ts < x. Furthermore we can
# only identify event_push_actions by a tuple of room_id, event_id
# we we can't use a subquery.
# Instead, we look up the stream ordering for the last event in that
# room received before the threshold time and delete event_push_actions
# in the room with a stream_odering before that.
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
txn.execute(
"""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""",
(room_id, user_id, stream_ordering),
)
def _action_has_highlight(actions):
for action in actions:

View File

@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="events",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="backfill",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,

View File

@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
)
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
sequence_name="receipts_sequence",
writers=hs.config.worker.writers.receipts,
)
else:
self._can_write_to_receipts = True
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
else:
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
@abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
raise NotImplementedError()
return self._receipts_id_gen.get_current_token()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
super().__init__(database, db_conn, hs)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
assert self._can_write_to_receipts
res = self.db_pool.simple_select_one_txn(
txn,
table="events",
@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
return None
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id,
receipt_type,
user_id,
)
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type),
)
self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph
representations.
"""
assert self._can_write_to_receipts
if not event_ids:
return None
@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
assert self._can_write_to_receipts
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data),
},
)
class ReceiptsStore(ReceiptsWorkerStore):
pass

View File

@ -0,0 +1,20 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
ALTER TABLE account_data ADD COLUMN instance_name TEXT;
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;

View File

@ -0,0 +1,32 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
-- We need to take the max across all the account_data tables as they share the
-- ID generator
SELECT setval('account_data_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
(SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
(SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
)
));
CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
SELECT setval('receipts_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
));

View File

@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
assert self._can_write_to_account_data
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
sql = (
@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room.
next_id: The the revision to advance to.
"""
assert self._can_write_to_account_data
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass
class TagsStore(TagsWorkerStore):
pass

View File

@ -17,7 +17,7 @@ import logging
import threading
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
stream_name: A name for the stream.
stream_name: A name for the stream, for use in the `stream_positions`
table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
tables: List of tables associated with the stream. Tuple of table
name, column name that stores the writer's instance name, and
column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
self._sequence_gen.check_consistency(
db_conn, table=table, id_column=id_column, positive=positive
)
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, table, instance_column, id_column)
self._load_current_ids(db_conn, tables)
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@ -306,6 +306,8 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
max_stream_id = 1
for table, _, id_column in tables:
sql = """
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
FROM %(table)s
@ -316,7 +318,10 @@ class MultiWriterIdGenerator:
}
cur.execute(sql)
(stream_id,) = cur.fetchone()
self._persisted_upto_position = stream_id
max_stream_id = max(max_stream_id, stream_id)
self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@ -329,6 +334,10 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
self._persisted_upto_position = min_stream_id
rows = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
WHERE ? %(cmp)s %(id)s
@ -340,10 +349,13 @@ class MultiWriterIdGenerator:
}
cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id
rows.extend(cur)
# Sort so that we handle rows in order for each instance.
rows.sort()
with self._lock:
for (instance, stream_id,) in cur:
for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)

View File

@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
)
@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
positive=False,
@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers,
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
):
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn):
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self):
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
self._insert_rows("foobar1", "first", 3)
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)