mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Replaces all usages of StreamIdGenerator
with MultiWriterIdGenerator
(#17229)
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer.
This commit is contained in:
parent
225f378ffa
commit
d16910ca02
1
changelog.d/17229.misc
Normal file
1
changelog.d/17229.misc
Normal file
@ -0,0 +1 @@
|
||||
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
|
@ -777,22 +777,74 @@ class Porter:
|
||||
await self._setup_events_stream_seqs()
|
||||
await self._setup_sequence(
|
||||
"un_partial_stated_event_stream_sequence",
|
||||
("un_partial_stated_event_stream",),
|
||||
[("un_partial_stated_event_stream", "stream_id")],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
||||
"device_inbox_sequence",
|
||||
[
|
||||
("device_inbox", "stream_id"),
|
||||
("device_federation_outbox", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"account_data_sequence",
|
||||
("room_account_data", "room_tags_revisions", "account_data"),
|
||||
[
|
||||
("room_account_data", "stream_id"),
|
||||
("room_tags_revisions", "stream_id"),
|
||||
("account_data", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"receipts_sequence",
|
||||
[
|
||||
("receipts_linearized", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"presence_stream_sequence",
|
||||
[
|
||||
("presence_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
|
||||
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
|
||||
await self._setup_auth_chain_sequence()
|
||||
await self._setup_sequence(
|
||||
"application_services_txn_id_seq",
|
||||
("application_services_txns",),
|
||||
"txn_id",
|
||||
[
|
||||
(
|
||||
"application_services_txns",
|
||||
"txn_id",
|
||||
)
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_lists_sequence",
|
||||
[
|
||||
("device_lists_stream", "stream_id"),
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"e2e_cross_signing_keys_sequence",
|
||||
[
|
||||
("e2e_cross_signing_keys", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"push_rules_stream_sequence",
|
||||
[
|
||||
("push_rules_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"pushers_sequence",
|
||||
[
|
||||
("pushers", "id"),
|
||||
("deleted_pushers", "stream_id"),
|
||||
],
|
||||
)
|
||||
|
||||
# Step 3. Get tables.
|
||||
@ -1101,12 +1153,11 @@ class Porter:
|
||||
async def _setup_sequence(
|
||||
self,
|
||||
sequence_name: str,
|
||||
stream_id_tables: Iterable[str],
|
||||
column_name: str = "stream_id",
|
||||
stream_id_tables: Iterable[Tuple[str, str]],
|
||||
) -> None:
|
||||
"""Set a sequence to the correct value."""
|
||||
current_stream_ids = []
|
||||
for stream_id_table in stream_id_tables:
|
||||
for stream_id_table, column_name in stream_id_tables:
|
||||
max_stream_id = cast(
|
||||
int,
|
||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||
|
@ -57,10 +57,7 @@ from synapse.storage.database import (
|
||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"device_lists_stream",
|
||||
"stream_id",
|
||||
extra_tables=[
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
self._device_list_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="device_lists_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("device_lists_stream", "instance_name", "stream_id"),
|
||||
("user_signature_stream", "instance_name", "stream_id"),
|
||||
("device_lists_outbound_pokes", "instance_name", "stream_id"),
|
||||
("device_lists_changes_in_room", "instance_name", "stream_id"),
|
||||
("device_lists_remote_pending", "instance_name", "stream_id"),
|
||||
],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
sequence_name="device_lists_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
"stream_id": stream_id,
|
||||
"from_user_id": from_user_id,
|
||||
"user_ids": json_encoder.encode(user_ids),
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"device_lists_stream_idx",
|
||||
index_name="device_lists_stream_user_id",
|
||||
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
"device_lists_outbound_pokes",
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see DeviceWorkerStore.__init__)
|
||||
_device_list_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_stream",
|
||||
keys=("stream_id", "user_id", "device_id"),
|
||||
keys=("instance_name", "stream_id", "user_id", "device_id"),
|
||||
values=[
|
||||
(stream_id, user_id, device_id)
|
||||
(self._instance_name, stream_id, user_id, device_id)
|
||||
for stream_id, device_id in zip(stream_ids, device_ids)
|
||||
],
|
||||
)
|
||||
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
values = [
|
||||
(
|
||||
destination,
|
||||
self._instance_name,
|
||||
next(stream_id_iterator),
|
||||
user_id,
|
||||
device_id,
|
||||
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
table="device_lists_outbound_pokes",
|
||||
keys=(
|
||||
"destination",
|
||||
"instance_name",
|
||||
"stream_id",
|
||||
"user_id",
|
||||
"device_id",
|
||||
@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
device_id,
|
||||
{
|
||||
stream_id: destination
|
||||
for (destination, stream_id, _, _, _, _, _) in values
|
||||
for (destination, _, stream_id, _, _, _, _, _) in values
|
||||
},
|
||||
)
|
||||
|
||||
@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
"device_id",
|
||||
"room_id",
|
||||
"stream_id",
|
||||
"instance_name",
|
||||
"converted_to_destinations",
|
||||
"opentracing_context",
|
||||
),
|
||||
@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
device_id,
|
||||
room_id,
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
# We only need to calculate outbound pokes for local users
|
||||
not self.hs.is_mine_id(user_id),
|
||||
encoded_context,
|
||||
@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={"stream_id": stream_id},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
desc="add_remote_device_list_to_pending",
|
||||
)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"e2e_cross_signing_keys",
|
||||
"stream_id",
|
||||
self._cross_signing_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="e2e_cross_signing_keys",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("e2e_cross_signing_keys", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="e2e_cross_signing_keys_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
async def set_e2e_device_keys(
|
||||
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder, unwrapFirstError
|
||||
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
|
||||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||
"""
|
||||
|
||||
_push_rules_stream_id_gen: StreamIdGenerator
|
||||
_push_rules_stream_id_gen: MultiWriterIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
|
||||
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
||||
)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"push_rules_stream",
|
||||
"stream_id",
|
||||
is_writer=self._is_push_writer,
|
||||
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="push_rules_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("push_rules_stream", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="push_rules_stream_sequence",
|
||||
writers=hs.config.worker.writers.push_rules,
|
||||
)
|
||||
|
||||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
||||
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
|
||||
raise Exception("Not a push writer")
|
||||
|
||||
values = {
|
||||
"instance_name": self._instance_name,
|
||||
"stream_id": stream_id,
|
||||
"event_stream_ordering": event_stream_ordering,
|
||||
"user_id": user_id,
|
||||
|
@ -40,10 +40,7 @@ from synapse.storage.database import (
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._pushers_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"pushers",
|
||||
"id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._pushers_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="pushers",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("pushers", "instance_name", "id"),
|
||||
("deleted_pushers", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="pushers_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
||||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see PusherWorkerStore.__init__)
|
||||
_pushers_id_gen: AbstractStreamIdGenerator
|
||||
_pushers_id_gen: MultiWriterIdGenerator
|
||||
|
||||
async def add_pusher(
|
||||
self,
|
||||
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"enabled": enabled,
|
||||
"device_id": device_id,
|
||||
# XXX(quenting): We're only really persisting the access token ID
|
||||
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
table="deleted_pushers",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_id": user_id,
|
||||
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="deleted_pushers",
|
||||
keys=("stream_id", "app_id", "pushkey", "user_id"),
|
||||
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
|
||||
values=[
|
||||
(stream_id, pusher.app_id, pusher.pushkey, user_id)
|
||||
(
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
pusher.app_id,
|
||||
pusher.pushkey,
|
||||
user_id,
|
||||
)
|
||||
for stream_id, pusher in zip(stream_ids, pushers)
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,27 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add `instance_name` columns to stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
|
@ -0,0 +1,54 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add squences for stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
|
||||
|
||||
-- We need to take the max across all the device lists tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('device_lists_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
|
||||
)
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
|
||||
|
||||
SELECT setval('e2e_cross_signing_keys_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
|
||||
|
||||
SELECT setval('push_rules_stream_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
|
||||
|
||||
-- We need to take the max across all the pusher tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('pushers_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(id), 1) FROM pushers),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
|
||||
)
|
||||
));
|
@ -23,15 +23,12 @@ import abc
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncContextManager,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
||||
|
||||
This class must only be used when the current Synapse process is the sole
|
||||
writer for a stream.
|
||||
|
||||
Args:
|
||||
db_conn(connection): A database connection to use to fetch the
|
||||
initial value of the generator from.
|
||||
table(str): A database table to read the initial value of the id
|
||||
generator from.
|
||||
column(str): The column of the database table to read the initial
|
||||
value from the id generator from.
|
||||
extra_tables(list): List of pairs of database tables and columns to
|
||||
use to source the initial value of the generator from. The value
|
||||
with the largest magnitude is used.
|
||||
step(int): which direction the stream ids grow in. +1 to grow
|
||||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
notifier: "ReplicationNotifier",
|
||||
table: str,
|
||||
column: str,
|
||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||
step: int = 1,
|
||||
is_writer: bool = True,
|
||||
) -> None:
|
||||
assert step != 0
|
||||
self._lock = threading.Lock()
|
||||
self._step: int = step
|
||||
self._current: int = _load_current_id(db_conn, table, column, step)
|
||||
self._is_writer = is_writer
|
||||
for table, column in extra_tables:
|
||||
self._current = (max if step > 0 else min)(
|
||||
self._current, _load_current_id(db_conn, table, column, step)
|
||||
)
|
||||
|
||||
# We use this as an ordered set, as we want to efficiently append items,
|
||||
# remove items and get the first item. Since we insert IDs in order, the
|
||||
# insertion ordering will ensure its in the correct ordering.
|
||||
#
|
||||
# The key and values are the same, but we never look at the values.
|
||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||
|
||||
self._notifier = notifier
|
||||
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
# Advance should never be called on a writer instance, only over replication
|
||||
if self._is_writer:
|
||||
raise Exception("Replication is not supported by writer StreamIdGenerator")
|
||||
|
||||
self._current = (max if self._step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[int, None, None]:
|
||||
try:
|
||||
yield next_id
|
||||
finally:
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
with self._lock:
|
||||
next_ids = range(
|
||||
self._current + self._step,
|
||||
self._current + self._step * (n + 1),
|
||||
self._step,
|
||||
)
|
||||
self._current += n * self._step
|
||||
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[Sequence[int], None, None]:
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
with self._lock:
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
||||
"""
|
||||
Retrieve the next stream ID from within a database transaction.
|
||||
|
||||
Clean-up functions will be called when the transaction finishes.
|
||||
|
||||
Args:
|
||||
txn: The database transaction object.
|
||||
|
||||
Returns:
|
||||
The next stream ID.
|
||||
"""
|
||||
if not self._is_writer:
|
||||
raise Exception("Tried to allocate stream ID on non-writer")
|
||||
|
||||
# Get the next stream ID.
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
def clear_unfinished_id(id_to_clear: int) -> None:
|
||||
"""A function to mark processing this ID as finished"""
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(id_to_clear)
|
||||
|
||||
# Mark this ID as finished once the database transaction itself finishes.
|
||||
txn.call_after(clear_unfinished_id, next_id)
|
||||
txn.call_on_exception(clear_unfinished_id, next_id)
|
||||
|
||||
# Return the new ID.
|
||||
return next_id
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
if not self._is_writer:
|
||||
return self._current
|
||||
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
return next(iter(self._unfinished_ids)) - self._step
|
||||
|
||||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
def get_minimal_local_current_token(self) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||
|
||||
|
@ -30,7 +30,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.storage.util.sequence import (
|
||||
LocalSequenceGenerator,
|
||||
PostgresSequenceGenerator,
|
||||
@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class StreamIdGeneratorTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.db_pool: DatabasePool = self.store.db_pool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
|
||||
|
||||
def _create_id_generator(self) -> StreamIdGenerator:
|
||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
||||
return StreamIdGenerator(
|
||||
db_conn=conn,
|
||||
notifier=self.hs.get_replication_notifier(),
|
||||
table="foobar",
|
||||
column="stream_id",
|
||||
)
|
||||
|
||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||
|
||||
def test_initial_value(self) -> None:
|
||||
"""Check that we read the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
def test_single_gen_next(self) -> None:
|
||||
"""Check that we correctly increment the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
async with id_gen.get_next() as next_id:
|
||||
# We haven't persisted `next_id` yet; current token is still 123
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
# But we did learn what the next value is
|
||||
self.assertEqual(next_id, 124)
|
||||
|
||||
# Once the context manager closes we assume that the `next_id` has been
|
||||
# written to the DB.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next sensibly."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist each in turn.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 125)
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next, even when their IDs
|
||||
created and persisted in different orders."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist them in a different order, starting with 126 from ctx3.
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
# We haven't persisted 124 from ctx1 yet---current token is still 123.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now persist 124 from ctx1.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
# Current token is then 124, waiting for 125 to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
# Finally persist 125 from ctx2.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Current token is then 126 (skipping over 125).
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request two new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
|
||||
# Persist ctx2 first.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Still waiting on ctx1's ID to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now request a third stream ID. It should be 126 (the smallest ID that
|
||||
# we've not yet handed out.)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
|
||||
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
|
Loading…
Reference in New Issue
Block a user