Handle replication commands synchronously where possible (#7876)

Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them.
This commit is contained in:
Richard van der Hoff 2020-07-27 18:54:43 +01:00 committed by GitHub
parent 7c2e2c2077
commit f57b99af22
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 86 deletions

View file

@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
self._logging_context = BackgroundProcessLoggingContext(
"replication_command_handler-%s" % self.conn_id
)
ctx_name = "replication-conn-%s" % self.conn_id
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
self._logging_context.request = ctx_name
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
self.handle_command(cmd)
async def handle_command(self, cmd: Command):
def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
`self.command_handler.on_<COMMAND>` if it exists. This allows for
protocol level handling of commands (e.g. PINGs), before delegating to
the handler.
`self.command_handler.on_<COMMAND>` if it exists (which can optionally
return an Awaitable).
This allows for protocol level handling of commands (e.g. PINGs), before
delegating to the handler.
Args:
cmd: received command
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(self, cmd)
res = cmd_func(self, cmd)
# the handler might be a coroutine: fire it off as a background process
# if so.
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
)
handled = True
if not handled:
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
async def on_PING(self, line):
def on_PING(self, line):
self.received_ping = True
async def on_ERROR(self, cmd):
def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
async def on_NAME(self, cmd):
def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
async def on_SERVER(self, cmd):
def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")