Add missing type hints to tests.replication. (#14987)

This commit is contained in:
Patrick Cloke 2023-02-06 09:55:00 -05:00 committed by GitHub
parent b275763c65
commit 156cd88eef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 193 additions and 149 deletions

View file

@ -16,7 +16,9 @@ from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler,
)
self._client_transport = None
self._server_transport = None
self._client_transport: Optional[FakeTransport] = None
self._server_transport: Optional[FakeTransport] = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
def reconnect(self) -> None:
if self._client_transport:
self.client.close()
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
def disconnect(self) -> None:
if self._client_transport:
self._client_transport = None
self.client.close()
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
self.server.close()
def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
def request_factory(*args, **kwargs):
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
base["redis"] = {"enabled": True}
return base
def setUp(self):
def setUp(self) -> None:
super().setUp()
# build a replication server
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_resource(self):
def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
def replicate(self):
def replicate(self) -> None:
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def _handle_http_replication_attempt(self, hs, repl_port):
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
def connect_any_redis_attempts(self):
def connect_any_redis_attempts(self) -> None:
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
client_protocol = client_factory.buildProtocol(client_address)
server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
async def on_rdata(self, stream_name, instance_name, token, rows):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
def __init__(self) -> None:
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
def add_subscriber(self, conn, channel: bytes):
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
def publish(self, conn, channel: bytes, msg) -> int:
def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
self._server = server
self._reader = hiredis.Reader()
def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)
# We might get multiple messages in one packet.
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""
# We currently only support pub/sub.
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("PONG")
else:
raise Exception(f"Unknown command: {command}")
raise Exception(f"Unknown command: {command!r}")
def send(self, msg):
def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.transport.write(raw)
self.transport.flush()
def encode(self, obj):
def encode(self, obj: object) -> str:
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
def connectionLost(self, reason):
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)