diff --git a/changelog.d/11938.misc b/changelog.d/11938.misc new file mode 100644 index 000000000..1d3a0030f --- /dev/null +++ b/changelog.d/11938.misc @@ -0,0 +1 @@ +Add missing type hints to replication code. diff --git a/mypy.ini b/mypy.ini index 2884078d0..cd28ac0dd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -169,6 +169,9 @@ disallow_untyped_defs = True [mypy-synapse.push.*] disallow_untyped_defs = True +[mypy-synapse.replication.*] +disallow_untyped_defs = True + [mypy-synapse.rest.*] disallow_untyped_defs = True diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index fa132d10b..8f3f953ed 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker): for table, column in extra_tables: self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, instance_name: Optional[str], new_id: int): + def advance(self, instance_name: Optional[str], new_id: int) -> None: self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index bc888ce1a..b5b84c09a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore): cache_name="client_ip_last_seen", max_size=50000 ) - async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip( + self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str + ) -> None: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index a2aff75b7..0ffd34f1d 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) @@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) - def _invalidate_caches_for_devices(self, token, rows): + def _invalidate_caches_for_devices( + self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] + ) -> None: for row in rows: # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 9d90e2637..d6f37d747 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): self._group_updates_id_gen.get_current_token(), ) - def get_group_stream_token(self): + def get_group_stream_token(self) -> int: return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == GroupServerStream.NAME: self._group_updates_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 7541e21de..52ee3f7e5 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Iterable from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -20,10 +21,12 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index cea90c0f1..de642bba7 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.tcp.streams import PushersStream from synapse.storage.database import DatabasePool, LoggingDatabaseConnection @@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): return self._pushers_id_gen.get_current_token() def process_replication_rows( - self, stream_name: str, instance_name: str, token, rows + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) # type: ignore + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e29ae1e37..d59ce7ccf 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,10 +14,12 @@ """A replication client for use by synapse workers. """ import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory +from twisted.python.failure import Failure from synapse.api.constants import EventTypes from synapse.federation import send_queue @@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory): hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) - def startedConnecting(self, connector): + def startedConnecting(self, connector: IConnector) -> None: logger.info("Connecting to replication: %r", connector.getDestination()) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol: logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( self.hs, @@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory): self.command_handler, ) - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.error("Lost replication conn: %r", reason) ReconnectingClientFactory.clientConnectionLost(self, connector, reason) - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.error("Failed to connect to replication: %r", reason) ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) @@ -131,7 +133,7 @@ class ReplicationDataHandler: async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -252,14 +254,16 @@ class ReplicationDataHandler: # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] - async def on_position(self, stream_name: str, instance_name: str, token: int): + async def on_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: await self.on_rdata(stream_name, instance_name, token, []) # We poke the generic "replication" notifier to wake anything up that # may be streaming. self.notifier.notify_replication() - def on_remote_server_up(self, server: str): + def on_remote_server_up(self, server: str) -> None: """Called when get a new REMOTE_SERVER_UP command.""" # Let's wake up the transaction queue for the server in case we have @@ -269,7 +273,7 @@ class ReplicationDataHandler: async def wait_for_stream_position( self, instance_name: str, stream_name: str, position: int - ): + ) -> None: """Wait until this instance has received updates up to and including the given stream position. """ @@ -304,7 +308,7 @@ class ReplicationDataHandler: "Finished waiting for repl stream %r to reach %s", stream_name, position ) - def stop_pusher(self, user_id, app_id, pushkey): + def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return @@ -316,13 +320,13 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id, app_id, pushkey): + async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) class FederationSenderHandler: @@ -353,10 +357,12 @@ class FederationSenderHandler: self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - def wake_destination(self, server: str): + def wake_destination(self, server: str) -> None: self.federation_sender.wake_destination(server) - async def process_replication_rows(self, stream_name, token, rows): + async def process_replication_rows( + self, stream_name: str, token: int, rows: list + ) -> None: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": @@ -384,11 +390,12 @@ class FederationSenderHandler: for host in hosts: self.federation_sender.send_device_messages(host) - async def _on_new_receipts(self, rows): + async def _on_new_receipts( + self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] + ) -> None: """ Args: - rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]): - new receipts to be processed + rows: new receipts to be processed """ for receipt in rows: # we only want to send on receipts for our own users @@ -408,7 +415,7 @@ class FederationSenderHandler: ) await self.federation_sender.send_read_receipt(receipt_info) - async def update_token(self, token): + async def update_token(self, token: int) -> None: """Update the record of where we have processed to in the federation stream. Called after we have processed a an update received over replication. Sends @@ -428,7 +435,7 @@ class FederationSenderHandler: run_as_background_process("_save_and_send_ack", self._save_and_send_ack) - async def _save_and_send_ack(self): + async def _save_and_send_ack(self) -> None: """Save the current federation position in the database and send an ACK to master with where we're up to. """ diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 1311b013d..3654f6c03 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,12 +18,15 @@ allowed to be sent by which side. """ import abc import logging -from typing import Tuple, Type +from typing import Optional, Tuple, Type, TypeVar +from synapse.replication.tcp.streams._base import StreamRow from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) +T = TypeVar("T", bound="Command") + class Command(metaclass=abc.ABCMeta): """The base command class. @@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: """Deserialises a line from the wire into this command. `line` does not include the command. """ @@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta): prefix. """ - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: """Get a suitable string for the logcontext when processing this command""" # by default, we just use the command name. return self.NAME +SC = TypeVar("SC", bound="_SimpleCommand") + + class _SimpleCommand(Command): """An implementation of Command whose argument is just a 'data' string.""" - def __init__(self, data): + def __init__(self, data: str): self.data = data @classmethod - def from_line(cls, line): + def from_line(cls: Type[SC], line: str) -> SC: return cls(line) def to_line(self) -> str: @@ -109,14 +115,16 @@ class RdataCommand(Command): NAME = "RDATA" - def __init__(self, stream_name, instance_name, token, row): + def __init__( + self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow + ): self.stream_name = stream_name self.instance_name = instance_name self.token = token self.row = row @classmethod - def from_line(cls, line): + def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand": stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( stream_name, @@ -125,7 +133,7 @@ class RdataCommand(Command): json_decoder.decode(row_json), ) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -135,7 +143,7 @@ class RdataCommand(Command): ) ) - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: return "RDATA-" + self.stream_name @@ -164,18 +172,20 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, prev_token, new_token): + def __init__( + self, stream_name: str, instance_name: str, prev_token: int, new_token: int + ): self.stream_name = stream_name self.instance_name = instance_name self.prev_token = prev_token self.new_token = new_token @classmethod - def from_line(cls, line): + def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand": stream_name, instance_name, prev_token, new_token = line.split(" ", 3) return cls(stream_name, instance_name, int(prev_token), int(new_token)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -218,14 +228,14 @@ class ReplicateCommand(Command): NAME = "REPLICATE" - def __init__(self): + def __init__(self) -> None: pass @classmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: return cls() - def to_line(self): + def to_line(self) -> str: return "" @@ -247,14 +257,16 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" - def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + def __init__( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): @@ -262,7 +274,7 @@ class UserSyncCommand(Command): return cls(instance_id, user_id, state == "start", int(last_sync_ms)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.instance_id, @@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command): NAME = "CLEAR_USER_SYNC" - def __init__(self, instance_id): + def __init__(self, instance_id: str): self.instance_id = instance_id @classmethod - def from_line(cls, line): + def from_line( + cls: Type["ClearUserSyncsCommand"], line: str + ) -> "ClearUserSyncsCommand": return cls(line) - def to_line(self): + def to_line(self) -> str: return self.instance_id @@ -316,7 +330,9 @@ class FederationAckCommand(Command): self.token = token @classmethod - def from_line(cls, line: str) -> "FederationAckCommand": + def from_line( + cls: Type["FederationAckCommand"], line: str + ) -> "FederationAckCommand": instance_name, token = line.split(" ") return cls(instance_name, int(token)) @@ -334,7 +350,15 @@ class UserIpCommand(Command): NAME = "USER_IP" - def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): + def __init__( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): self.user_id = user_id self.access_token = access_token self.ip = ip @@ -343,14 +367,14 @@ class UserIpCommand(Command): self.last_seen = last_seen @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand": user_id, jsn = line.split(" ", 1) access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) - def to_line(self): + def to_line(self) -> str: return ( self.user_id + " " diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index f7e6bc1e6..17e157239 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -261,7 +261,7 @@ class ReplicationCommandHandler: "process-replication-data", self._unsafe_process_queue, stream_name ) - async def _unsafe_process_queue(self, stream_name: str): + async def _unsafe_process_queue(self, stream_name: str) -> None: """Processes the command queue for the given stream, until it is empty Does not check if there is already a thread processing the queue, hence "unsafe" @@ -294,7 +294,7 @@ class ReplicationCommandHandler: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs: "HomeServer"): + def start_replication(self, hs: "HomeServer") -> None: """Helper method to start a replication connection to the remote server using TCP. """ @@ -345,10 +345,10 @@ class ReplicationCommandHandler: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None: self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: IReplicationConnection): + def send_positions_to_connection(self, conn: IReplicationConnection) -> None: """Send current position of all streams this process is source of to the connection. """ @@ -392,7 +392,7 @@ class ReplicationCommandHandler: def on_FEDERATION_ACK( self, conn: IReplicationConnection, cmd: FederationAckCommand - ): + ) -> None: federation_ack_counter.inc() if self._federation_sender: @@ -408,7 +408,7 @@ class ReplicationCommandHandler: else: return None - async def _handle_user_ip(self, cmd: UserIpCommand): + async def _handle_user_ip(self, cmd: UserIpCommand) -> None: await self._store.insert_client_ip( cmd.user_id, cmd.access_token, @@ -421,7 +421,7 @@ class ReplicationCommandHandler: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -497,7 +497,7 @@ class ReplicationCommandHandler: async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. Args: @@ -512,7 +512,7 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -581,7 +581,7 @@ class ReplicationCommandHandler: def on_REMOTE_SERVER_UP( self, conn: IReplicationConnection, cmd: RemoteServerUpCommand - ): + ) -> None: """Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -604,7 +604,7 @@ class ReplicationCommandHandler: # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: IReplicationConnection): + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -631,7 +631,7 @@ class ReplicationCommandHandler: UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: IReplicationConnection): + def lost_connection(self, connection: IReplicationConnection) -> None: """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) @@ -653,7 +653,7 @@ class ReplicationCommandHandler: def send_command( self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None - ): + ) -> None: """Send a command to all connected connections. Args: @@ -680,7 +680,7 @@ class ReplicationCommandHandler: else: logger.warning("Dropping command as not connected: %r", cmd.NAME) - def send_federation_ack(self, token: int): + def send_federation_ack(self, token: int) -> None: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ @@ -688,7 +688,7 @@ class ReplicationCommandHandler: def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int - ): + ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) @@ -702,15 +702,15 @@ class ReplicationCommandHandler: user_agent: str, device_id: str, last_seen: int, - ): + ) -> None: """Tell the master that the user made a request.""" cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) - def send_remote_server_up(self, server: str): + def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) - def stream_update(self, stream_name: str, token: str, data: Any): + def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None: """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7bae36db1..7763ffb2d 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,7 +49,7 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, Collection, List, Optional +from typing import TYPE_CHECKING, Any, Collection, List, Optional from prometheus_client import Counter from zope.interface import Interface, implementer @@ -123,7 +123,7 @@ class ConnectionStates: class IReplicationConnection(Interface): """An interface for replication connections.""" - def send_command(cmd: Command): + def send_command(cmd: Command) -> None: """Send the command down the connection""" @@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-conn", self.conn_id ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("[%s] Connection established", self.id()) self.state = ConnectionStates.ESTABLISHED @@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Always send the initial PING so that the other side knows that they # can time us out. - self.send_command(PingCommand(self.clock.time_msec())) + self.send_command(PingCommand(str(self.clock.time_msec()))) self.command_handler.new_connection(self) - def send_ping(self): + def send_ping(self) -> None: """Periodically sends a ping and checks if we should close the connection due to the other side timing out. """ @@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: - self.send_command(PingCommand(now)) + self.send_command(PingCommand(str(now))) if ( self.received_ping @@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line: bytes): + def lineReceived(self, line: bytes) -> None: """Called when we've received a line""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_line(line) - def _parse_and_dispatch_line(self, line: bytes): + def _parse_and_dispatch_line(self, line: bytes) -> None: if line.strip() == "": # Ignore blank lines return @@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if not handled: logger.warning("Unhandled command: %r", cmd) - def close(self): + def close(self) -> None: logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() - def send_error(self, error_string, *args): + def send_error(self, error_string: str, *args: Any) -> None: """Send an error to remote and close the connection.""" self.send_command(ErrorCommand(error_string % args)) self.close() - def send_command(self, cmd, do_buffer=True): + def send_command(self, cmd: Command, do_buffer: bool = True) -> None: """Send a command if connection has been established. Args: - cmd (Command) - do_buffer (bool): Whether to buffer the message or always attempt + cmd + do_buffer: Whether to buffer the message or always attempt to send the command. This is mostly used to send an error message if we're about to close the connection due our buffers becoming full. @@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_sent_command = self.clock.time_msec() - def _queue_command(self, cmd): + def _queue_command(self, cmd: Command) -> None: """Queue the command until the connection is ready to write to again.""" logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) @@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) self.close() - def _send_pending_commands(self): + def _send_pending_commands(self) -> None: """Send any queued commandes""" pending = self.pending_commands self.pending_commands = [] for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + def on_PING(self, cmd: PingCommand) -> None: self.received_ping = True - def on_ERROR(self, cmd): + def on_ERROR(self, cmd: ErrorCommand) -> None: logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) - def pauseProducing(self): + def pauseProducing(self) -> None: """This is called when both the kernel send buffer and the twisted tcp connection send buffers have become full. @@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info("[%s] Pause producing", self.id()) self.state = ConnectionStates.PAUSED - def resumeProducing(self): + def resumeProducing(self) -> None: """The remote has caught up after we started buffering!""" logger.info("[%s] Resume producing", self.id()) self.state = ConnectionStates.ESTABLISHED self._send_pending_commands() - def stopProducing(self): + def stopProducing(self) -> None: """We're never going to send any more data (normally because either we or the remote has closed the connection) """ logger.info("[%s] Stop producing", self.id()) self.on_connection_closed() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.labels(reason.__class__.__name__).inc() + connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable] try: # Remove us from list of connections to be monitored @@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.on_connection_closed() - def on_connection_closed(self): + def on_connection_closed(self) -> None: logger.info("[%s] Connection was closed", self.id()) self.state = ConnectionStates.CLOSED @@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def __str__(self): + def __str__(self) -> str: addr = None if self.transport: addr = str(self.transport.getPeer()) @@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): addr, ) - def id(self): + def id(self) -> str: return "%s-%s" % (self.name, self.conn_id) - def lineLengthExceeded(self, line): + def lineLengthExceeded(self, line: str) -> None: """Called when we receive a line that is above the maximum line length""" self.send_error("Line length exceeded") @@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(ServerCommand(self.server_name)) super().connectionMade() - def on_NAME(self, cmd): + def on_NAME(self, cmd: NameCommand) -> None: logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(NameCommand(self.client_name)) super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - def on_SERVER(self, cmd): + def on_SERVER(self, cmd: ServerCommand) -> None: if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def replicate(self): + def replicate(self) -> None: """Send the subscription request to the server""" logger.info("[%s] Subscribing to replication streams", self.id()) @@ -529,7 +529,7 @@ pending_commands = LaterGauge( ) -def transport_buffer_size(protocol): +def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int: if protocol.transport: size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen return size @@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge( ) -def transport_kernel_read_buffer_size(protocol, read=True): +def transport_kernel_read_buffer_size( + protocol: BaseReplicationStreamProtocol, read: bool = True +) -> int: SIOCINQ = 0x541B SIOCOUTQ = 0x5411 diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 5b37f379d..3170f7c59 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,7 +14,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast import attr import txredisapi @@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]): def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant - def __set__(self, obj: Optional[T], value: V): + def __set__(self, obj: Optional[T], value: V) -> None: pass @@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_stream_name: str synapse_outbound_redis_connection: txredisapi.RedisProtocol - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # a logcontext which we use for processing incoming commands. We declare it as a @@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): "replication_command_handler" ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() run_as_background_process("subscribe-replication", self._send_subscribe) - async def _send_subscribe(self): + async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. @@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # otherside won't know we've connected and so won't issue a REPLICATE. self.synapse_handler.send_positions_to_connection(self) - def messageReceived(self, pattern: str, channel: str, message: str): + def messageReceived(self, pattern: str, channel: str, message: str) -> None: """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) - def _parse_and_dispatch_message(self, message: str): + def _parse_and_dispatch_message(self, message: str) -> None: if message.strip() == "": # Ignore blank lines return @@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): "replication-" + cmd.get_logcontext_id(), lambda: res ) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("Lost connection to redis") super().connectionLost(reason) self.synapse_handler.lost_connection(self) @@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def send_command(self, cmd: Command): + def send_command(self, cmd: Command) -> None: """Send a command if connection has been established. Args: - cmd (Command) + cmd: The command to send """ run_as_background_process( "send-cmd", self._async_send_command, cmd, bg_start_span=False ) - async def _async_send_command(self, cmd: Command): + async def _async_send_command(self, cmd: Command) -> None: """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: @@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory): hs.get_clock().looping_call(self._send_ping, 30 * 1000) @wrap_as_background_process("redis_ping") - async def _send_ping(self): + async def _send_ping(self) -> None: for connection in self.pool: try: await make_deferred_yieldable(connection.ping()) @@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory): # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but # it's rubbish. We add our own here. - def startedConnecting(self, connector: IConnector): + def startedConnecting(self, connector: IConnector) -> None: logger.info( "Connecting to redis server %s", format_address(connector.getDestination()) ) super().startedConnecting(connector) - def clientConnectionFailed(self, connector: IConnector, reason: Failure): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.info( "Connection to redis server %s failed: %s", format_address(connector.getDestination()), @@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory): ) super().clientConnectionFailed(connector, reason) - def clientConnectionLost(self, connector: IConnector, reason: Failure): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.info( "Connection to redis server %s lost: %s", format_address(connector.getDestination()), @@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): self.synapse_outbound_redis_connection = outbound_redis_connection - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> RedisSubscriber: p = super().buildProtocol(addr) p = cast(RedisSubscriber, p) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index a9d85f4f6..ecd6190f5 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,16 +16,18 @@ import logging import random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter +from twisted.internet.interfaces import IAddress from twisted.internet.protocol import ServerFactory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream +from synapse.replication.tcp.streams._base import StreamRow, Token from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): # listener config again or always starting a `ReplicationStreamer`.) hs.get_replication_streamer() - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol: return ServerReplicationStreamProtocol( self.server_name, self.clock, self.command_handler ) @@ -105,7 +107,7 @@ class ReplicationStreamer: if any(EventsStream.NAME == s.NAME for s in self.streams): self.clock.looping_call(self.on_notifier_poke, 1000) - def on_notifier_poke(self): + def on_notifier_poke(self) -> None: """Checks if there is actually any new data and sends it to the connections if there are. @@ -137,7 +139,7 @@ class ReplicationStreamer: run_as_background_process("replication_notifier", self._run_notifier_loop) - async def _run_notifier_loop(self): + async def _run_notifier_loop(self) -> None: self.is_looping = True try: @@ -238,7 +240,9 @@ class ReplicationStreamer: self.is_looping = False -def _batch_updates(updates): +def _batch_updates( + updates: List[Tuple[Token, StreamRow]] +) -> List[Tuple[Optional[Token], StreamRow]]: """Takes a list of updates of form [(token, row)] and sets the token to None for all rows where the next row has the same token. This is used to implement batching. @@ -254,7 +258,7 @@ def _batch_updates(updates): if not updates: return [] - new_updates = [] + new_updates: List[Tuple[Optional[Token], StreamRow]] = [] for i, update in enumerate(updates[:-1]): if update[0] == updates[i + 1][0]: new_updates.append((None, update[1])) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 5a2d90c53..914b9eae8 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -90,7 +90,7 @@ class Stream: ROW_TYPE: Any = None @classmethod - def parse_row(cls, row: StreamRow): + def parse_row(cls, row: StreamRow) -> Any: """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -139,7 +139,7 @@ class Stream: # The token from which we last asked for updates self.last_token = self.current_token(self.local_instance_name) - def discard_updates_and_advance(self): + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ @@ -200,7 +200,7 @@ def current_token_without_instance( return lambda instance_name: current_token() -def make_http_update_function(hs, stream_name: str) -> UpdateFunction: +def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. """ diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 4f4f1ad45..50c4a5ba0 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast import attr -from ._base import Stream, StreamUpdateResult, Token +from synapse.replication.tcp.streams._base import ( + Stream, + StreamRow, + StreamUpdateResult, + Token, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -58,6 +62,9 @@ class EventsStreamRow: data: "BaseEventsStreamRow" +T = TypeVar("T", bound="BaseEventsStreamRow") + + class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. @@ -68,7 +75,7 @@ class BaseEventsStreamRow: TypeId: str @classmethod - def from_data(cls, data): + def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T: """Parse the data from the replication stream into a row. By default we just call the constructor with the data list as arguments @@ -221,7 +228,7 @@ class EventsStream(Stream): return updates, upper_limit, limited @classmethod - def parse_row(cls, row): - (typ, data) = row - data = TypeToRow[typ].from_data(data) - return EventsStreamRow(typ, data) + def parse_row(cls, row: StreamRow) -> "EventsStreamRow": + (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row) + event_stream_row_data = TypeToRow[typ].from_data(data) + return EventsStreamRow(typ, event_stream_row_data) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index ea1032b4f..b26546aec 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -16,8 +16,7 @@ import itertools import re import secrets import string -from collections.abc import Iterable -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from netaddr import valid_ipv6 @@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. - Otherwise, return the stringification of a a list with the first maxitems items, + Otherwise, return the stringification of a list with the first maxitems items, followed by "...". Args: diff --git a/tests/replication/_base.py b/tests/replication/_base.py index cb02eddf0..9fc50f885 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,6 +14,7 @@ import logging from typing import Any, Dict, List, Optional, Tuple +from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol from twisted.web.resource import Resource @@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( - None + IPv4Address("TCP", "127.0.0.1", 0) ) # Make a new HomeServer object for the worker @@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.clock, repl_handler, ) - server = self.server_factory.buildProtocol(None) + server = self.server_factory.buildProtocol( + IPv4Address("TCP", "127.0.0.1", 0) + ) client_transport = FakeTransport(server, self.reactor) client.makeConnection(client_transport) diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py index 262c35cef..545f11acd 100644 --- a/tests/replication/tcp/test_remote_server_up.py +++ b/tests/replication/tcp/test_remote_server_up.py @@ -14,6 +14,7 @@ from typing import Tuple +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IProtocol from twisted.test.proto_helpers import StringTransport @@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase): def _make_client(self) -> Tuple[IProtocol, StringTransport]: """Create a new direct TCP replication connection""" - proto = self.factory.buildProtocol(("127.0.0.1", 0)) + proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0)) transport = StringTransport() proto.makeConnection(transport)