Move connecting logic into ClientReplicationStreamProtocol

This commit is contained in:
Erik Johnston 2019-02-27 10:22:52 +00:00
parent 09fc34c935
commit 6870fc496f
2 changed files with 17 additions and 18 deletions

View File

@ -89,11 +89,6 @@ class ReplicationClientHandler(object):
# Used for tests. # Used for tests.
self.awaiting_syncs = {} self.awaiting_syncs = {}
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = None
# The factory used to create connections. # The factory used to create connections.
self.factory = None self.factory = None
@ -122,12 +117,6 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more. Can be overriden in subclasses to handle more.
""" """
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(stream_name)
if not self.streams_connecting:
self.finished_connecting()
return self.store.process_replication_rows(stream_name, token, []) return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data): def on_sync(self, data):
@ -154,9 +143,6 @@ class ReplicationClientHandler(object):
elif room_account_data: elif room_account_data:
args["account_data"] = room_account_data args["account_data"] = room_account_data
# Record which streams we're in the process of subscribing to
self.streams_connecting = set(args.keys())
return args return args
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
@ -222,10 +208,6 @@ class ReplicationClientHandler(object):
connection.send_command(cmd) connection.send_command(cmd)
self.pending_commands = [] self.pending_commands = []
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.finished_connecting()
def finished_connecting(self): def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all """Called when we have successfully subscribed and caught up to all
streams we're interested in. streams we're interested in.

View File

@ -511,6 +511,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name self.server_name = server_name
self.handler = handler self.handler = handler
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set()
# Map of stream to batched updates. See RdataCommand for info on how # Map of stream to batched updates. See RdataCommand for info on how
# batching works. # batching works.
self.pending_batches = {} self.pending_batches = {}
@ -533,6 +538,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler # We've now finished connecting to so inform the client handler
self.handler.update_connection(self) self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
def on_SERVER(self, cmd): def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -562,6 +571,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
return self.handler.on_rdata(stream_name, cmd.token, rows) return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
return self.handler.on_position(cmd.stream_name, cmd.token) return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd): def on_SYNC(self, cmd):
@ -578,6 +593,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.id(), stream_name, token self.id(), stream_name, token
) )
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token)) self.send_command(ReplicateCommand(stream_name, token))
def on_connection_closed(self): def on_connection_closed(self):