Move catchup of replication streams to worker. (#7024)

This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
This commit is contained in:
Erik Johnston 2020-03-25 14:54:01 +00:00 committed by GitHub
parent 7bab642707
commit 4cff617df1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 635 additions and 487 deletions

View file

@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from typing import Any, DefaultDict, Dict, List, Set
from six import iteritems, iterkeys
from six import iteritems
from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
# Subscribe to all streams we're publishing to.
for stream_name in self.streamer.streams_by_name:
current_token = self.streamer.get_stream_token(stream_name)
self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set() # type: Set[str]
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, Any]
self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
)
raise
if cmd.token is None:
if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
# Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
if updates:
await self.handler.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token)
# Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
def replicate(self):
"""Send the subscription request to the server
"""
if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,))
logger.info("[%s] Subscribing to replication streams", self.id())
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)