diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5a..0e8a38fd8 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -99,7 +99,8 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReconnectingClientFactory] - # The currently connected connections. + # The currently connected connections. (The list of places we need to send + # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] LaterGauge( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 617e860f9..61dbf4ddb 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import PreserveLoggingContext +from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( Command, @@ -41,8 +41,14 @@ logger = logging.getLogger(__name__) class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. - Parses incoming messages from redis into replication commands, and passes - them to `ReplicationCommandHandler` + This class fulfils two functions: + + (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis + connection, parsing *incoming* messages into replication commands, and passing them + to `ReplicationCommandHandler` + + (b) it implements the AbstractConnection API, where it sends *outgoing* commands + onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom constructor, so instead we expect the defined attributes below to be set @@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Attributes: handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to (not anything to - do with Synapse replication streams). + stream_name: The *redis* stream name to subscribe to and publish from + (not anything to do with Synapse replication streams). outbound_redis_connection: The connection to redis to use to send commands. """ @@ -61,12 +67,22 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): - logger.info("Connected to redis instance") - self.subscribe(self.stream_name) - self.send_command(ReplicateCommand()) - + logger.info("Connected to redis") + run_as_background_process("subscribe-replication", self._send_subscribe) self.handler.new_connection(self) + async def _send_subscribe(self): + # it's important to make sure that we only send the REPLICATE command once we + # have successfully subscribed to the stream - otherwise we might miss the + # POSITION response sent back by the other end. + logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) + await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info( + "Successfully subscribed to redis stream, sending REPLICATE command" + ) + await self._async_send_command(ReplicateCommand()) + logger.info("REPLICATE successfully sent") + def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ @@ -119,7 +135,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): - logger.info("Lost connection to redis instance") + logger.info("Lost connection to redis") self.handler.lost_connection(self) def send_command(self, cmd: Command): @@ -128,6 +144,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ + run_as_background_process("send-cmd", self._send_command, cmd) + + async def _async_send_command(self, cmd: Command): + """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -138,15 +158,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() - async def _send(): - with PreserveLoggingContext(): - # Note that we use the other connection as we can't send - # commands using the subscription connection. - await self.outbound_redis_connection.publish( - self.stream_name, encoded_string - ) - - run_as_background_process("send-cmd", _send) + await make_deferred_yieldable( + self.outbound_redis_connection.publish(self.stream_name, encoded_string) + ) class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):