mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 12:16:09 -04:00
Port synapse.replication.tcp to async/await (#6666)
* Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_<FOO> functions as async * Don't be overenthusiastic with the asyncing....
This commit is contained in:
parent
19a1aac48c
commit
48c3a96886
15 changed files with 80 additions and 105 deletions
|
@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import (
|
|||
SyncCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP
|
||||
from synapse.types import Collection
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from .streams import STREAMS_MAP
|
||||
|
||||
connection_close_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
||||
)
|
||||
|
@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
|
||||
def handle_command(self, cmd):
|
||||
async def handle_command(self, cmd: Command):
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
|
||||
Args:
|
||||
cmd (synapse.replication.tcp.commands.Command): received command
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
cmd: received command
|
||||
"""
|
||||
handler = getattr(self, "on_%s" % (cmd.NAME,))
|
||||
return handler(cmd)
|
||||
await handler(cmd)
|
||||
|
||||
def close(self):
|
||||
logger.warning("[%s] Closing connection", self.id())
|
||||
|
@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
for cmd in pending:
|
||||
self.send_command(cmd)
|
||||
|
||||
def on_PING(self, line):
|
||||
async def on_PING(self, line):
|
||||
self.received_ping = True
|
||||
|
||||
def on_ERROR(self, cmd):
|
||||
async def on_ERROR(self, cmd):
|
||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||
|
||||
def pauseProducing(self):
|
||||
|
@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
self.streamer.new_connection(self)
|
||||
|
||||
def on_NAME(self, cmd):
|
||||
async def on_NAME(self, cmd):
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
def on_USER_SYNC(self, cmd):
|
||||
return self.streamer.on_user_sync(
|
||||
async def on_USER_SYNC(self, cmd):
|
||||
await self.streamer.on_user_sync(
|
||||
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
def on_REPLICATE(self, cmd):
|
||||
async def on_REPLICATE(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
token = cmd.token
|
||||
|
||||
|
@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
for stream in iterkeys(self.streamer.streams_by_name)
|
||||
]
|
||||
|
||||
return make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
else:
|
||||
return self.subscribe_to_stream(stream_name, token)
|
||||
await self.subscribe_to_stream(stream_name, token)
|
||||
|
||||
def on_FEDERATION_ACK(self, cmd):
|
||||
return self.streamer.federation_ack(cmd.token)
|
||||
async def on_FEDERATION_ACK(self, cmd):
|
||||
self.streamer.federation_ack(cmd.token)
|
||||
|
||||
def on_REMOVE_PUSHER(self, cmd):
|
||||
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
||||
async def on_REMOVE_PUSHER(self, cmd):
|
||||
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
||||
|
||||
def on_INVALIDATE_CACHE(self, cmd):
|
||||
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||
async def on_INVALIDATE_CACHE(self, cmd):
|
||||
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||
|
||||
def on_USER_IP(self, cmd):
|
||||
return self.streamer.on_user_ip(
|
||||
async def on_USER_IP(self, cmd):
|
||||
self.streamer.on_user_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
|
@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
cmd.last_seen,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def subscribe_to_stream(self, stream_name, token):
|
||||
async def subscribe_to_stream(self, stream_name, token):
|
||||
"""Subscribe the remote to a stream.
|
||||
|
||||
This invloves checking if they've missed anything and sending those
|
||||
|
@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
|
||||
try:
|
||||
# Get missing updates
|
||||
updates, current_token = yield self.streamer.get_stream_updates(
|
||||
updates, current_token = await self.streamer.get_stream_updates(
|
||||
stream_name, token
|
||||
)
|
||||
|
||||
|
@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
|
@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
|
||||
Returns:
|
||||
Deferred|None
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_position(self, stream_name, token):
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
def on_SERVER(self, cmd):
|
||||
async def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
||||
def on_RDATA(self, cmd):
|
||||
async def on_RDATA(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
|
@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
return self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
await self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
def on_POSITION(self, cmd):
|
||||
async def on_POSITION(self, cmd):
|
||||
# When we get a `POSITION` command it means we've finished getting
|
||||
# missing updates for the given stream, and are now up to date.
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
return self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
await self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
def on_SYNC(self, cmd):
|
||||
return self.handler.on_sync(cmd.data)
|
||||
async def on_SYNC(self, cmd):
|
||||
self.handler.on_sync(cmd.data)
|
||||
|
||||
def replicate(self, stream_name, token):
|
||||
"""Send the subscription request to the server
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue