Wake up transaction queue when remote server comes back online (#6706)

This will be used to retry outbound transactions to a remote server if
we think it might have come back up.
This commit is contained in:
Erik Johnston 2020-01-17 10:27:19 +00:00 committed by GitHub
parent 5ce0b17e38
commit a8a50f5b57
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 135 additions and 8 deletions

View file

@ -143,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
if d:
d.callback(data)
def on_remote_server_up(self, server: str):
    """Called when we get a new REMOTE_SERVER_UP command.

    This implementation takes no action and returns ``None``.

    Args:
        server: the value carried by the command.
    """
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.

View file

@ -387,6 +387,20 @@ class UserIpCommand(Command):
)
class RemoteServerUpCommand(Command):
    """Announces that a remote server previously considered "down" appears
    to be reachable again, so its retry timings should be reset.

    Sent by whichever worker detected the change. When it arrives from a
    client, the server relays it to all other workers.

    Format::

        REMOTE_SERVER_UP <server>
    """

    NAME = "REMOTE_SERVER_UP"
_COMMANDS = (
ServerCommand,
RdataCommand,
@ -401,6 +415,7 @@ _COMMANDS = (
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
SyncCommand.NAME,
RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
)

View file

@ -76,6 +76,7 @@ from synapse.replication.tcp.commands import (
PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
@ -460,6 +461,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_INVALIDATE_CACHE(self, cmd):
    """Handle an INVALIDATE_CACHE command.

    Passes the command's ``cache_func`` and ``keys`` on to the streamer's
    ``on_invalidate_cache``.
    """
    invalidate = self.streamer.on_invalidate_cache
    invalidate(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
    """Handle a REMOTE_SERVER_UP command.

    Passes the command's ``data`` to the streamer's
    ``on_remote_server_up``.
    """
    notify_streamer = self.streamer.on_remote_server_up
    notify_streamer(cmd.data)
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
@ -555,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
    """Build a SYNC command from ``data`` and hand it to ``send_command``."""
    send = self.send_command
    send(SyncCommand(data))
def send_remote_server_up(self, server: str):
    """Build a REMOTE_SERVER_UP command for ``server`` and hand it to
    ``send_command``.
    """
    send = self.send_command
    send(RemoteServerUpCommand(server))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
@ -588,6 +595,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""Called when get a new SYNC command."""
raise NotImplementedError()
@abc.abstractmethod
def on_remote_server_up(self, server: str):
    """Called when we get a new REMOTE_SERVER_UP command.

    Declared as a plain (non-async) method: the client protocol invokes
    it without awaiting the result, so a coroutine implementation would
    never actually run.

    Args:
        server: the value carried by the command (its ``data``).

    Raises:
        NotImplementedError: always; implementations must override this.
    """
    raise NotImplementedError()
@abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
@ -707,6 +719,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_SYNC(self, cmd):
    """Handle a SYNC command by passing its ``data`` to the handler's
    ``on_sync``.
    """
    deliver = self.handler.on_sync
    deliver(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
    """Handle a REMOTE_SERVER_UP command.

    Passes the command's ``data`` to the handler's
    ``on_remote_server_up``; the result of that call is not awaited.
    """
    notify_handler = self.handler.on_remote_server_up
    notify_handler(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
"""

View file

@ -120,6 +120,7 @@ class ReplicationStreamer(object):
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@ -288,6 +289,14 @@ class ReplicationStreamer(object):
)
await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
    """Called when we get a new REMOTE_SERVER_UP command.

    Passes ``server`` to the notifier's ``notify_remote_server_up``.
    """
    notifier = self.notifier
    notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
    """Call ``send_remote_server_up(server)`` on each entry of
    ``self.connections``, in iteration order.
    """
    for connection in self.connections:
        forward = connection.send_remote_server_up
        forward(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.