mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 07:34:19 -05:00
Port synapse.replication.tcp to async/await (#6666)
* Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_<FOO> functions as async * Don't be overenthusiastic with the asyncing....
This commit is contained in:
parent
19a1aac48c
commit
48c3a96886
1
changelog.d/6666.misc
Normal file
1
changelog.d/6666.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Port `synapse.replication.tcp` to async/await.
|
@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer):
|
|||||||
|
|
||||||
|
|
||||||
class AdminCmdReplicationHandler(ReplicationClientHandler):
|
class AdminCmdReplicationHandler(ReplicationClientHandler):
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
def get_streams_to_replicate(self):
|
||||||
|
@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler):
|
|||||||
super(ASReplicationHandler, self).__init__(hs.get_datastore())
|
super(ASReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
|
|
||||||
|
|
||||||
if stream_name == "events":
|
if stream_name == "events":
|
||||||
max_stream_id = self.store.get_room_max_stream_ordering()
|
max_stream_id = self.store.get_room_max_stream_ordering()
|
||||||
|
@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
|
|||||||
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
|
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.send_handler = FederationSenderHandler(hs, self)
|
self.send_handler = FederationSenderHandler(hs, self)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
await super(FederationSenderReplicationHandler, self).on_rdata(
|
||||||
yield super(FederationSenderReplicationHandler, self).on_rdata(
|
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
self.send_handler.process_replication_rows(stream_name, token, rows)
|
self.send_handler.process_replication_rows(stream_name, token, rows)
|
||||||
|
@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler):
|
|||||||
|
|
||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
|
|
||||||
run_in_background(self.poke_pushers, stream_name, token, rows)
|
run_in_background(self.poke_pushers, stream_name, token, rows)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
|||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
|
||||||
run_in_background(self.process_and_notify, stream_name, token, rows)
|
run_in_background(self.process_and_notify, stream_name, token, rows)
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
def get_streams_to_replicate(self):
|
||||||
|
@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
|
|||||||
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
|
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
|
||||||
self.user_directory = hs.get_user_directory_handler()
|
self.user_directory = hs.get_user_directory_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
def on_rdata(self, stream_name, token, rows):
|
await super(UserDirectoryReplicationHandler, self).on_rdata(
|
||||||
yield super(UserDirectoryReplicationHandler, self).on_rdata(
|
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
if stream_name == EventsStream.NAME:
|
if stream_name == EventsStream.NAME:
|
||||||
|
@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object):
|
|||||||
def federation_ack(self, token):
|
def federation_ack(self, token):
|
||||||
self._clear_queue_before_pos(token)
|
self._clear_queue_before_pos(token)
|
||||||
|
|
||||||
def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
|
async def get_replication_rows(
|
||||||
|
self, from_token, to_token, limit, federation_ack=None
|
||||||
|
):
|
||||||
"""Get rows to be sent over federation between the two tokens
|
"""Get rows to be sent over federation between the two tokens
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -257,7 +257,7 @@ class TypingHandler(object):
|
|||||||
"typing_key", self._latest_room_serial, rooms=[member.room_id]
|
"typing_key", self._latest_room_serial, rooms=[member.room_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_typing_updates(self, last_id, current_id):
|
async def get_all_typing_updates(self, last_id, current_id):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
|
|||||||
port = hs.config.worker_replication_port
|
port = hs.config.worker_replication_port
|
||||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||||
|
|
||||||
def on_rdata(self, stream_name, token, rows):
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||||
@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
|
|||||||
token (int): stream token for this batch of rows
|
token (int): stream token for this batch of rows
|
||||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||||
Stream.parse_row.
|
Stream.parse_row.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred|None
|
|
||||||
"""
|
"""
|
||||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||||
return self.store.process_replication_rows(stream_name, token, rows)
|
self.store.process_replication_rows(stream_name, token, rows)
|
||||||
|
|
||||||
def on_position(self, stream_name, token):
|
async def on_position(self, stream_name, token):
|
||||||
"""Called when we get new position data. By default this just pokes
|
"""Called when we get new position data. By default this just pokes
|
||||||
the slave store.
|
the slave store.
|
||||||
|
|
||||||
Can be overriden in subclasses to handle more.
|
Can be overriden in subclasses to handle more.
|
||||||
"""
|
"""
|
||||||
return self.store.process_replication_rows(stream_name, token, [])
|
self.store.process_replication_rows(stream_name, token, [])
|
||||||
|
|
||||||
def on_sync(self, data):
|
def on_sync(self, data):
|
||||||
"""When we received a SYNC we wake up any deferreds that were waiting
|
"""When we received a SYNC we wake up any deferreds that were waiting
|
||||||
|
@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import (
|
|||||||
SyncCommand,
|
SyncCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
)
|
)
|
||||||
|
from synapse.replication.tcp.streams import STREAMS_MAP
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from .streams import STREAMS_MAP
|
|
||||||
|
|
||||||
connection_close_counter = Counter(
|
connection_close_counter = Counter(
|
||||||
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
||||||
)
|
)
|
||||||
@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_command(self, cmd):
|
async def handle_command(self, cmd: Command):
|
||||||
"""Handle a command we have received over the replication stream.
|
"""Handle a command we have received over the replication stream.
|
||||||
|
|
||||||
By default delegates to on_<COMMAND>
|
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cmd (synapse.replication.tcp.commands.Command): received command
|
cmd: received command
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
handler = getattr(self, "on_%s" % (cmd.NAME,))
|
handler = getattr(self, "on_%s" % (cmd.NAME,))
|
||||||
return handler(cmd)
|
await handler(cmd)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
logger.warning("[%s] Closing connection", self.id())
|
logger.warning("[%s] Closing connection", self.id())
|
||||||
@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
for cmd in pending:
|
for cmd in pending:
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
def on_PING(self, line):
|
async def on_PING(self, line):
|
||||||
self.received_ping = True
|
self.received_ping = True
|
||||||
|
|
||||||
def on_ERROR(self, cmd):
|
async def on_ERROR(self, cmd):
|
||||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self):
|
||||||
@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
BaseReplicationStreamProtocol.connectionMade(self)
|
BaseReplicationStreamProtocol.connectionMade(self)
|
||||||
self.streamer.new_connection(self)
|
self.streamer.new_connection(self)
|
||||||
|
|
||||||
def on_NAME(self, cmd):
|
async def on_NAME(self, cmd):
|
||||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||||
self.name = cmd.data
|
self.name = cmd.data
|
||||||
|
|
||||||
def on_USER_SYNC(self, cmd):
|
async def on_USER_SYNC(self, cmd):
|
||||||
return self.streamer.on_user_sync(
|
await self.streamer.on_user_sync(
|
||||||
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_REPLICATE(self, cmd):
|
async def on_REPLICATE(self, cmd):
|
||||||
stream_name = cmd.stream_name
|
stream_name = cmd.stream_name
|
||||||
token = cmd.token
|
token = cmd.token
|
||||||
|
|
||||||
@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
for stream in iterkeys(self.streamer.streams_by_name)
|
for stream in iterkeys(self.streamer.streams_by_name)
|
||||||
]
|
]
|
||||||
|
|
||||||
return make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
defer.gatherResults(deferreds, consumeErrors=True)
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.subscribe_to_stream(stream_name, token)
|
await self.subscribe_to_stream(stream_name, token)
|
||||||
|
|
||||||
def on_FEDERATION_ACK(self, cmd):
|
async def on_FEDERATION_ACK(self, cmd):
|
||||||
return self.streamer.federation_ack(cmd.token)
|
self.streamer.federation_ack(cmd.token)
|
||||||
|
|
||||||
def on_REMOVE_PUSHER(self, cmd):
|
async def on_REMOVE_PUSHER(self, cmd):
|
||||||
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
||||||
|
|
||||||
def on_INVALIDATE_CACHE(self, cmd):
|
async def on_INVALIDATE_CACHE(self, cmd):
|
||||||
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||||
|
|
||||||
def on_USER_IP(self, cmd):
|
async def on_USER_IP(self, cmd):
|
||||||
return self.streamer.on_user_ip(
|
self.streamer.on_user_ip(
|
||||||
cmd.user_id,
|
cmd.user_id,
|
||||||
cmd.access_token,
|
cmd.access_token,
|
||||||
cmd.ip,
|
cmd.ip,
|
||||||
@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
cmd.last_seen,
|
cmd.last_seen,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def subscribe_to_stream(self, stream_name, token):
|
||||||
def subscribe_to_stream(self, stream_name, token):
|
|
||||||
"""Subscribe the remote to a stream.
|
"""Subscribe the remote to a stream.
|
||||||
|
|
||||||
This invloves checking if they've missed anything and sending those
|
This invloves checking if they've missed anything and sending those
|
||||||
@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get missing updates
|
# Get missing updates
|
||||||
updates, current_token = yield self.streamer.get_stream_updates(
|
updates, current_token = await self.streamer.get_stream_updates(
|
||||||
stream_name, token
|
stream_name, token
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def on_rdata(self, stream_name, token, rows):
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|||||||
token (int): stream token for this batch of rows
|
token (int): stream token for this batch of rows
|
||||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||||
Stream.parse_row.
|
Stream.parse_row.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred|None
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def on_position(self, stream_name, token):
|
async def on_position(self, stream_name, token):
|
||||||
"""Called when we get new position data."""
|
"""Called when we get new position data."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
if not self.streams_connecting:
|
if not self.streams_connecting:
|
||||||
self.handler.finished_connecting()
|
self.handler.finished_connecting()
|
||||||
|
|
||||||
def on_SERVER(self, cmd):
|
async def on_SERVER(self, cmd):
|
||||||
if cmd.data != self.server_name:
|
if cmd.data != self.server_name:
|
||||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||||
self.send_error("Wrong remote")
|
self.send_error("Wrong remote")
|
||||||
|
|
||||||
def on_RDATA(self, cmd):
|
async def on_RDATA(self, cmd):
|
||||||
stream_name = cmd.stream_name
|
stream_name = cmd.stream_name
|
||||||
inbound_rdata_count.labels(stream_name).inc()
|
inbound_rdata_count.labels(stream_name).inc()
|
||||||
|
|
||||||
@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
# Check if this is the last of a batch of updates
|
# Check if this is the last of a batch of updates
|
||||||
rows = self.pending_batches.pop(stream_name, [])
|
rows = self.pending_batches.pop(stream_name, [])
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
return self.handler.on_rdata(stream_name, cmd.token, rows)
|
await self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||||
|
|
||||||
def on_POSITION(self, cmd):
|
async def on_POSITION(self, cmd):
|
||||||
# When we get a `POSITION` command it means we've finished getting
|
# When we get a `POSITION` command it means we've finished getting
|
||||||
# missing updates for the given stream, and are now up to date.
|
# missing updates for the given stream, and are now up to date.
|
||||||
self.streams_connecting.discard(cmd.stream_name)
|
self.streams_connecting.discard(cmd.stream_name)
|
||||||
if not self.streams_connecting:
|
if not self.streams_connecting:
|
||||||
self.handler.finished_connecting()
|
self.handler.finished_connecting()
|
||||||
|
|
||||||
return self.handler.on_position(cmd.stream_name, cmd.token)
|
await self.handler.on_position(cmd.stream_name, cmd.token)
|
||||||
|
|
||||||
def on_SYNC(self, cmd):
|
async def on_SYNC(self, cmd):
|
||||||
return self.handler.on_sync(cmd.data)
|
self.handler.on_sync(cmd.data)
|
||||||
|
|
||||||
def replicate(self, stream_name, token):
|
def replicate(self, stream_name, token):
|
||||||
"""Send the subscription request to the server
|
"""Send the subscription request to the server
|
||||||
|
@ -23,7 +23,6 @@ from six import itervalues
|
|||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
|
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
@ -155,8 +154,7 @@ class ReplicationStreamer(object):
|
|||||||
|
|
||||||
run_as_background_process("replication_notifier", self._run_notifier_loop)
|
run_as_background_process("replication_notifier", self._run_notifier_loop)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _run_notifier_loop(self):
|
||||||
def _run_notifier_loop(self):
|
|
||||||
self.is_looping = True
|
self.is_looping = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -185,7 +183,7 @@ class ReplicationStreamer(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if self._replication_torture_level:
|
if self._replication_torture_level:
|
||||||
yield self.clock.sleep(
|
await self.clock.sleep(
|
||||||
self._replication_torture_level / 1000.0
|
self._replication_torture_level / 1000.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,7 +194,7 @@ class ReplicationStreamer(object):
|
|||||||
stream.upto_token,
|
stream.upto_token,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
updates, current_token = yield stream.get_updates()
|
updates, current_token = await stream.get_updates()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info("Failed to handle stream %s", stream.NAME)
|
logger.info("Failed to handle stream %s", stream.NAME)
|
||||||
raise
|
raise
|
||||||
@ -233,7 +231,7 @@ class ReplicationStreamer(object):
|
|||||||
self.is_looping = False
|
self.is_looping = False
|
||||||
|
|
||||||
@measure_func("repl.get_stream_updates")
|
@measure_func("repl.get_stream_updates")
|
||||||
def get_stream_updates(self, stream_name, token):
|
async def get_stream_updates(self, stream_name, token):
|
||||||
"""For a given stream get all updates since token. This is called when
|
"""For a given stream get all updates since token. This is called when
|
||||||
a client first subscribes to a stream.
|
a client first subscribes to a stream.
|
||||||
"""
|
"""
|
||||||
@ -241,7 +239,7 @@ class ReplicationStreamer(object):
|
|||||||
if not stream:
|
if not stream:
|
||||||
raise Exception("unknown stream %s", stream_name)
|
raise Exception("unknown stream %s", stream_name)
|
||||||
|
|
||||||
return stream.get_updates_since(token)
|
return await stream.get_updates_since(token)
|
||||||
|
|
||||||
@measure_func("repl.federation_ack")
|
@measure_func("repl.federation_ack")
|
||||||
def federation_ack(self, token):
|
def federation_ack(self, token):
|
||||||
@ -252,22 +250,20 @@ class ReplicationStreamer(object):
|
|||||||
self.federation_sender.federation_ack(token)
|
self.federation_sender.federation_ack(token)
|
||||||
|
|
||||||
@measure_func("repl.on_user_sync")
|
@measure_func("repl.on_user_sync")
|
||||||
@defer.inlineCallbacks
|
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
||||||
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
|
||||||
"""A client has started/stopped syncing on a worker.
|
"""A client has started/stopped syncing on a worker.
|
||||||
"""
|
"""
|
||||||
user_sync_counter.inc()
|
user_sync_counter.inc()
|
||||||
yield self.presence_handler.update_external_syncs_row(
|
await self.presence_handler.update_external_syncs_row(
|
||||||
conn_id, user_id, is_syncing, last_sync_ms
|
conn_id, user_id, is_syncing, last_sync_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
@measure_func("repl.on_remove_pusher")
|
@measure_func("repl.on_remove_pusher")
|
||||||
@defer.inlineCallbacks
|
async def on_remove_pusher(self, app_id, push_key, user_id):
|
||||||
def on_remove_pusher(self, app_id, push_key, user_id):
|
|
||||||
"""A client has asked us to remove a pusher
|
"""A client has asked us to remove a pusher
|
||||||
"""
|
"""
|
||||||
remove_pusher_counter.inc()
|
remove_pusher_counter.inc()
|
||||||
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
|
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||||
app_id=app_id, pushkey=push_key, user_id=user_id
|
app_id=app_id, pushkey=push_key, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -281,15 +277,16 @@ class ReplicationStreamer(object):
|
|||||||
getattr(self.store, cache_func).invalidate(tuple(keys))
|
getattr(self.store, cache_func).invalidate(tuple(keys))
|
||||||
|
|
||||||
@measure_func("repl.on_user_ip")
|
@measure_func("repl.on_user_ip")
|
||||||
@defer.inlineCallbacks
|
async def on_user_ip(
|
||||||
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
self, user_id, access_token, ip, user_agent, device_id, last_seen
|
||||||
|
):
|
||||||
"""The client saw a user request
|
"""The client saw a user request
|
||||||
"""
|
"""
|
||||||
user_ip_cache_counter.inc()
|
user_ip_cache_counter.inc()
|
||||||
yield self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id, access_token, ip, user_agent, device_id, last_seen
|
user_id, access_token, ip, user_agent, device_id, last_seen
|
||||||
)
|
)
|
||||||
yield self._server_notices_sender.on_user_ip(user_id)
|
await self._server_notices_sender.on_user_ip(user_id)
|
||||||
|
|
||||||
def send_sync_to_all_connections(self, data):
|
def send_sync_to_all_connections(self, data):
|
||||||
"""Sends a SYNC command to all clients.
|
"""Sends a SYNC command to all clients.
|
||||||
|
@ -19,8 +19,6 @@ import logging
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -144,8 +142,7 @@ class Stream(object):
|
|||||||
self.upto_token = self.current_token()
|
self.upto_token = self.current_token()
|
||||||
self.last_token = self.upto_token
|
self.last_token = self.upto_token
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_updates(self):
|
||||||
def get_updates(self):
|
|
||||||
"""Gets all updates since the last time this function was called (or
|
"""Gets all updates since the last time this function was called (or
|
||||||
since the stream was constructed if it hadn't been called before),
|
since the stream was constructed if it hadn't been called before),
|
||||||
until the `upto_token`
|
until the `upto_token`
|
||||||
@ -156,13 +153,12 @@ class Stream(object):
|
|||||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||||
sent over the replication steam.
|
sent over the replication steam.
|
||||||
"""
|
"""
|
||||||
updates, current_token = yield self.get_updates_since(self.last_token)
|
updates, current_token = await self.get_updates_since(self.last_token)
|
||||||
self.last_token = current_token
|
self.last_token = current_token
|
||||||
|
|
||||||
return updates, current_token
|
return updates, current_token
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_updates_since(self, from_token):
|
||||||
def get_updates_since(self, from_token):
|
|
||||||
"""Like get_updates except allows specifying from when we should
|
"""Like get_updates except allows specifying from when we should
|
||||||
stream updates
|
stream updates
|
||||||
|
|
||||||
@ -182,15 +178,16 @@ class Stream(object):
|
|||||||
if from_token == current_token:
|
if from_token == current_token:
|
||||||
return [], current_token
|
return [], current_token
|
||||||
|
|
||||||
|
logger.info("get_updates_since: %s", self.__class__)
|
||||||
if self._LIMITED:
|
if self._LIMITED:
|
||||||
rows = yield self.update_function(
|
rows = await self.update_function(
|
||||||
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
|
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
|
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
|
||||||
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
|
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
|
||||||
else:
|
else:
|
||||||
rows = yield self.update_function(from_token, current_token)
|
rows = await self.update_function(from_token, current_token)
|
||||||
|
|
||||||
updates = [(row[0], row[1:]) for row in rows]
|
updates = [(row[0], row[1:]) for row in rows]
|
||||||
|
|
||||||
@ -295,9 +292,8 @@ class PushRulesStream(Stream):
|
|||||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||||
return push_rules_token
|
return push_rules_token
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_function(self, from_token, to_token, limit):
|
||||||
def update_function(self, from_token, to_token, limit):
|
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||||
rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
|
||||||
return [(row[0], row[2]) for row in rows]
|
return [(row[0], row[2]) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
@ -413,9 +409,8 @@ class AccountDataStream(Stream):
|
|||||||
|
|
||||||
super(AccountDataStream, self).__init__(hs)
|
super(AccountDataStream, self).__init__(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_function(self, from_token, to_token, limit):
|
||||||
def update_function(self, from_token, to_token, limit):
|
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||||
global_results, room_results = yield self.store.get_all_updated_account_data(
|
|
||||||
from_token, from_token, to_token, limit
|
from_token, from_token, to_token, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,8 +19,6 @@ from typing import Tuple, Type
|
|||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from ._base import Stream
|
from ._base import Stream
|
||||||
|
|
||||||
|
|
||||||
@ -122,16 +120,15 @@ class EventsStream(Stream):
|
|||||||
|
|
||||||
super(EventsStream, self).__init__(hs)
|
super(EventsStream, self).__init__(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_function(self, from_token, current_token, limit=None):
|
||||||
def update_function(self, from_token, current_token, limit=None):
|
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||||
event_rows = yield self._store.get_all_new_forward_event_rows(
|
|
||||||
from_token, current_token, limit
|
from_token, current_token, limit
|
||||||
)
|
)
|
||||||
event_updates = (
|
event_updates = (
|
||||||
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
|
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
|
||||||
)
|
)
|
||||||
|
|
||||||
state_rows = yield self._store.get_all_updated_current_state_deltas(
|
state_rows = await self._store.get_all_updated_current_state_deltas(
|
||||||
from_token, current_token, limit
|
from_token, current_token, limit
|
||||||
)
|
)
|
||||||
state_updates = (
|
state_updates = (
|
||||||
|
@ -73,6 +73,6 @@ class TestReplicationClientHandler(object):
|
|||||||
def finished_connecting(self):
|
def finished_connecting(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_rdata(self, stream_name, token, rows):
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.received_rdata_rows.append((stream_name, token, r))
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
|
Loading…
Reference in New Issue
Block a user