Always notify replication when a stream advances (#14877)

This ensures that all other workers are told about stream updates in a timely manner, without having to remember to manually poke replication.
This commit is contained in:
Erik Johnston 2023-01-20 18:02:18 +00:00 committed by GitHub
parent cf18fea9e1
commit 65d0386693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 104 additions and 29 deletions

1
changelog.d/14877.misc Normal file
View File

@ -0,0 +1 @@
Always notify replication when a stream advances automatically.

View File

@ -51,6 +51,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
run_in_background, run_in_background,
) )
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@ -260,6 +261,9 @@ class MockHomeserver:
def should_send_federation(self) -> bool: def should_send_federation(self) -> bool:
return False return False
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()
class Porter: class Porter:
def __init__( def __init__(

View File

@ -226,8 +226,7 @@ class Notifier:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = [] self.pending_new_room_events: List[_PendingRoomEventEntry] = []
# Called when there are new things to stream over replication self._replication_notifier = hs.get_replication_notifier()
self.replication_callbacks: List[Callable[[], None]] = []
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = [] self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
self._federation_client = hs.get_federation_http_client() self._federation_client = hs.get_federation_http_client()
@ -279,7 +278,7 @@ class Notifier:
it needs to do any asynchronous work, a background thread should be started and it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process. wrapped with run_as_background_process.
""" """
self.replication_callbacks.append(cb) self._replication_notifier.add_replication_callback(cb)
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None: def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room. """Add a callback that will be called when a user joins a room.
@ -741,8 +740,7 @@ class Notifier:
def notify_replication(self) -> None: def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event""" """Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks: self._replication_notifier.notify_replication()
cb()
def notify_user_joined_room(self, event_id: str, room_id: str) -> None: def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks: for cb in self._new_join_in_room_callbacks:
@ -759,3 +757,26 @@ class Notifier:
# Tell the federation client about the fact the server is back up, so # Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried. # that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server) self._federation_client.wake_destination(server)
@attr.s(auto_attribs=True)
class ReplicationNotifier:
"""Tracks callbacks for things that need to know about stream changes.
This is separate from the notifier to avoid circular dependencies.
"""
_replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
def add_replication_callback(self, cb: Callable[[], None]) -> None:
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self._replication_callbacks.append(cb)
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self._replication_callbacks:
cb()

View File

@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.notifier import Notifier from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.client import ReplicationDataHandler
@ -389,6 +389,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_notifier(self) -> Notifier: def get_notifier(self) -> Notifier:
return Notifier(self) return Notifier(self)
@cache_in_self
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()
@cache_in_self @cache_in_self
def get_auth(self) -> Auth: def get_auth(self) -> Auth:
return Auth(self) return Auth(self)

View File

@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen = MultiWriterIdGenerator( self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="account_data", stream_name="account_data",
instance_name=self._instance_name, instance_name=self._instance_name,
tables=[ tables=[
@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# SQLite). # SQLite).
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"room_account_data", "room_account_data",
"stream_id", "stream_id",
extra_tables=[("room_tags_revisions", "stream_id")], extra_tables=[("room_tags_revisions", "stream_id")],

View File

@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._cache_id_gen = MultiWriterIdGenerator( self._cache_id_gen = MultiWriterIdGenerator(
db_conn, db_conn,
database, database,
notifier=hs.get_replication_notifier(),
stream_name="caches", stream_name="caches",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
tables=[ tables=[

View File

@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
MultiWriterIdGenerator( MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="to_device", stream_name="to_device",
instance_name=self._instance_name, instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")], tables=[("device_inbox", "instance_name", "stream_id")],
@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
else: else:
self._can_write_to_device = True self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator( self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id" db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()

View File

@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# class below that is used on the main process. # class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"device_lists_stream", "device_lists_stream",
"stream_id", "stream_id",
extra_tables=[ extra_tables=[

View File

@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator( self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id" db_conn,
hs.get_replication_notifier(),
"e2e_cross_signing_keys",
"stream_id",
) )
async def set_e2e_device_keys( async def set_e2e_device_keys(

View File

@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator( self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="events", stream_name="events",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")], tables=[("events", "instance_name", "stream_ordering")],
@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
self._backfill_id_gen = MultiWriterIdGenerator( self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="backfill", stream_name="backfill",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")], tables=[("events", "instance_name", "stream_ordering")],
@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite). # SQLite).
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"events", "events",
"stream_ordering", "stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events, is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
) )
self._backfill_id_gen = StreamIdGenerator( self._backfill_id_gen = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"events", "events",
"stream_ordering", "stream_ordering",
step=-1, step=-1,
@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream", stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
tables=[ tables=[
@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
) )
else: else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator( self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_event_stream", "stream_id" db_conn,
hs.get_replication_notifier(),
"un_partial_stated_event_stream",
"stream_id",
) )
def get_un_partial_stated_events_token(self) -> int: def get_un_partial_stated_events_token(self) -> int:

View File

@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_id_gen = MultiWriterIdGenerator( self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="presence_stream", stream_name="presence_stream",
instance_name=self._instance_name, instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")], tables=[("presence_stream", "instance_name", "stream_id")],
@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
) )
else: else:
self._presence_id_gen = StreamIdGenerator( self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id" db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
) )
self.hs = hs self.hs = hs

View File

@ -118,6 +118,7 @@ class PushRulesWorkerStore(
# class below that is used on the main process. # class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"push_rules_stream", "push_rules_stream",
"stream_id", "stream_id",
is_writer=hs.config.worker.worker_app is None, is_writer=hs.config.worker.worker_app is None,

View File

@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
# class below that is used on the main process. # class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"pushers", "pushers",
"id", "id",
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],

View File

@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator( self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="receipts", stream_name="receipts",
instance_name=self._instance_name, instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")], tables=[("receipts_linearized", "instance_name", "stream_id")],
@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# SQLite). # SQLite).
self._receipts_id_gen = StreamIdGenerator( self._receipts_id_gen = StreamIdGenerator(
db_conn, db_conn,
hs.get_replication_notifier(),
"receipts_linearized", "receipts_linearized",
"stream_id", "stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,

View File

@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
db=database, db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream", stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name, instance_name=self._instance_name,
tables=[ tables=[
@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
) )
else: else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_room_stream", "stream_id" db_conn,
hs.get_replication_notifier(),
"un_partial_stated_room_stream",
"stream_id",
) )
async def store_room( async def store_room(

View File

@ -20,6 +20,7 @@ from collections import OrderedDict
from contextlib import contextmanager from contextlib import contextmanager
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
TYPE_CHECKING,
AsyncContextManager, AsyncContextManager,
ContextManager, ContextManager,
Dict, Dict,
@ -49,6 +50,9 @@ from synapse.storage.database import (
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator from synapse.storage.util.sequence import PostgresSequenceGenerator
if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__( def __init__(
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str, table: str,
column: str, column: str,
extra_tables: Iterable[Tuple[str, str]] = (), extra_tables: Iterable[Tuple[str, str]] = (),
@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values. # The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict() self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
self._notifier = notifier
def advance(self, instance_name: str, new_id: int) -> None: def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication # Advance should never be called on a writer instance, only over replication
if self._is_writer: if self._is_writer:
@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
with self._lock: with self._lock:
self._unfinished_ids.pop(next_id) self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.pop(next_id) self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int: def get_current_token(self) -> int:
@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
db: DatabasePool, db: DatabasePool,
notifier: "ReplicationNotifier",
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
tables: List[Tuple[str, str, str]], tables: List[Tuple[str, str, str]],
@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive: bool = True, positive: bool = True,
) -> None: ) -> None:
self._db = db self._db = db
self._notifier = notifier
self._stream_name = stream_name self._stream_name = stream_name
self._instance_name = instance_name self._instance_name = instance_name
self._positive = positive self._positive = positive
@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids, # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields # controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids. # a single integer stream_id; otherwise it yields a list of stream_ids.
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) return cast(
AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
)
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this # If we have a list of instances that are allowed to write to this
@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
raise Exception("Tried to allocate stream ID on non-writer") raise Exception("Tried to allocate stream ID on non-writer")
# Cast safety: see get_next. # Cast safety: see get_next.
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n)) return cast(
AsyncContextManager[List[int]],
_MultiWriterCtxManager(self, self._notifier, n),
)
def get_next_txn(self, txn: LoggingTransaction) -> int: def get_next_txn(self, txn: LoggingTransaction) -> int:
""" """
@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(self._mark_id_as_finished, next_id) txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream # Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't # ID (unless self._writers is not set in which case we don't
@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator""" """Async context manager returned by MultiWriterIdGenerator"""
id_gen: MultiWriterIdGenerator id_gen: MultiWriterIdGenerator
notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list) stream_ids: List[int] = attr.Factory(list)
@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
for i in self.stream_ids: for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i) self.id_gen._mark_id_as_finished(i)
self.notifier.notify_replication()
if exc_type is not None: if exc_type is not None:
return False return False

View File

@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.send_local_online_presence_to([remote_user_id]) self.module_api.send_local_online_presence_to([remote_user_id])
) )
# We don't always send out federation immediately, so we advance the clock.
self.reactor.advance(1000)
# Check that a presence update was sent as part of a federation transaction # Check that a presence update was sent as part of a federation transaction
found_update = False found_update = False
calls = ( calls = (

View File

@ -14,7 +14,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.replication.tcp.commands import PositionCommand, RdataCommand from synapse.replication.tcp.commands import PositionCommand
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
next_token = self.get_success(ctx.__aenter__()) next_token = self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None)) self.get_success(ctx.__aexit__(None, None, None))
cmd_handler.send_command(
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
)
self.replicate()
self.get_success( self.get_success(
data_handler.wait_for_stream_position("worker1", "caches", next_token) data_handler.wait_for_stream_position("worker1", "caches", next_token)
) )
# `wait_for_stream_position` should only return once master receives an # `wait_for_stream_position` should only return once master receives a
# RDATA from the worker # notification that `next_token` has persisted.
ctx = cache_id_gen.get_next() ctx_worker1 = cache_id_gen.get_next()
next_token = self.get_success(ctx.__aenter__()) next_token = self.get_success(ctx_worker1.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
d = defer.ensureDeferred( d = defer.ensureDeferred(
data_handler.wait_for_stream_position("worker1", "caches", next_token) data_handler.wait_for_stream_position("worker1", "caches", next_token)
@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
) )
self.assertFalse(d.called) self.assertFalse(d.called)
# ... but receiving the RDATA should # ... but worker1 finishing (and so sending an update) should.
cmd_handler.send_command( self.get_success(ctx_worker1.__aexit__(None, None, None))
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
)
self.replicate()
self.assertTrue(d.called) self.assertTrue(d.called)

View File

@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator( return StreamIdGenerator(
db_conn=conn, db_conn=conn,
notifier=self.hs.get_replication_notifier(),
table="foobar", table="foobar",
column="stream_id", column="stream_id",
) )
@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
conn, conn,
self.db_pool, self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[("foobar", "instance_name", "stream_id")],
@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
conn, conn,
self.db_pool, self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[("foobar", "instance_name", "stream_id")],
@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
conn, conn,
self.db_pool, self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
tables=[ tables=[