Port synapse.replication.tcp to async/await (#6666)

* Port synapse.replication.tcp to async/await

* Newsfile

* Correctly document type of on_<FOO> functions as async

* Don't be overenthusiastic with the asyncing....
This commit is contained in:
Erik Johnston 2020-01-16 09:16:12 +00:00 committed by GitHub
parent 19a1aac48c
commit 48c3a96886
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 80 additions and 105 deletions

View file

@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
from .streams import STREAMS_MAP
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
def handle_command(self, cmd):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd (synapse.replication.tcp.commands.Command): received command
Returns:
Deferred
cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
return handler(cmd)
await handler(cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
def on_PING(self, line):
async def on_PING(self, line):
self.received_ping = True
def on_ERROR(self, cmd):
async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
def on_NAME(self, cmd):
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
def on_USER_SYNC(self, cmd):
return self.streamer.on_user_sync(
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
def on_REPLICATE(self, cmd):
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
for stream in iterkeys(self.streamer.streams_by_name)
]
return make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
return self.subscribe_to_stream(stream_name, token)
await self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
return self.streamer.federation_ack(cmd.token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
def on_INVALIDATE_CACHE(self, cmd):
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_INVALIDATE_CACHE(self, cmd):
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
return self.streamer.on_user_ip(
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
@defer.inlineCallbacks
def subscribe_to_stream(self, stream_name, token):
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
updates, current_token = yield self.streamer.get_stream_updates(
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def on_rdata(self, stream_name, token, rows):
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
Returns:
Deferred|None
"""
raise NotImplementedError()
@abc.abstractmethod
def on_position(self, stream_name, token):
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
def on_SERVER(self, cmd):
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows)
await self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
return self.handler.on_position(cmd.stream_name, cmd.token)
await self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
return self.handler.on_sync(cmd.data)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server