Add missing type hints to synapse.replication. (#11938)

This commit is contained in:
Patrick Cloke 2022-02-08 11:03:08 -05:00 committed by GitHub
parent 8c94b3abe9
commit d0e78af35e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 209 additions and 147 deletions

View file

@ -18,12 +18,15 @@ allowed to be sent by which side.
"""
import abc
import logging
from typing import Tuple, Type
from typing import Optional, Tuple, Type, TypeVar
from synapse.replication.tcp.streams._base import StreamRow
from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="Command")
class Command(metaclass=abc.ABCMeta):
"""The base command class.
@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def from_line(cls, line):
def from_line(cls: Type[T], line: str) -> T:
"""Deserialises a line from the wire into this command. `line` does not
include the command.
"""
@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
prefix.
"""
def get_logcontext_id(self):
def get_logcontext_id(self) -> str:
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
SC = TypeVar("SC", bound="_SimpleCommand")
class _SimpleCommand(Command):
"""An implementation of Command whose argument is just a 'data' string."""
def __init__(self, data):
def __init__(self, data: str):
self.data = data
@classmethod
def from_line(cls, line):
def from_line(cls: Type[SC], line: str) -> SC:
return cls(line)
def to_line(self) -> str:
@ -109,14 +115,16 @@ class RdataCommand(Command):
NAME = "RDATA"
def __init__(self, stream_name, instance_name, token, row):
def __init__(
self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
stream_name,
@ -125,7 +133,7 @@ class RdataCommand(Command):
json_decoder.decode(row_json),
)
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@ -135,7 +143,7 @@ class RdataCommand(Command):
)
)
def get_logcontext_id(self):
def get_logcontext_id(self) -> str:
return "RDATA-" + self.stream_name
@ -164,18 +172,20 @@ class PositionCommand(Command):
NAME = "POSITION"
def __init__(self, stream_name, instance_name, prev_token, new_token):
def __init__(
self, stream_name: str, instance_name: str, prev_token: int, new_token: int
):
self.stream_name = stream_name
self.instance_name = instance_name
self.prev_token = prev_token
self.new_token = new_token
@classmethod
def from_line(cls, line):
def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@ -218,14 +228,14 @@ class ReplicateCommand(Command):
NAME = "REPLICATE"
def __init__(self):
def __init__(self) -> None:
pass
@classmethod
def from_line(cls, line):
def from_line(cls: Type[T], line: str) -> T:
return cls()
def to_line(self):
def to_line(self) -> str:
return ""
@ -247,14 +257,16 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
def __init__(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
@ -262,7 +274,7 @@ class UserSyncCommand(Command):
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.instance_id,
@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id):
def __init__(self, instance_id: str):
self.instance_id = instance_id
@classmethod
def from_line(cls, line):
def from_line(
cls: Type["ClearUserSyncsCommand"], line: str
) -> "ClearUserSyncsCommand":
return cls(line)
def to_line(self):
def to_line(self) -> str:
return self.instance_id
@ -316,7 +330,9 @@ class FederationAckCommand(Command):
self.token = token
@classmethod
def from_line(cls, line: str) -> "FederationAckCommand":
def from_line(
cls: Type["FederationAckCommand"], line: str
) -> "FederationAckCommand":
instance_name, token = line.split(" ")
return cls(instance_name, int(token))
@ -334,7 +350,15 @@ class UserIpCommand(Command):
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
def __init__(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
self.user_id = user_id
self.access_token = access_token
self.ip = ip
@ -343,14 +367,14 @@ class UserIpCommand(Command):
self.last_seen = last_seen
@classmethod
def from_line(cls, line):
def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
def to_line(self):
def to_line(self) -> str:
return (
self.user_id
+ " "