mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-04 18:48:47 -04:00
Add missing type hints to tests.replication. (#14987)
This commit is contained in:
parent
b275763c65
commit
156cd88eef
21 changed files with 193 additions and 149 deletions
|
@ -16,7 +16,9 @@ from collections import defaultdict
|
|||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.protocol import Protocol, connectionDone
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
|
@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
|
|||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
|
@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
repl_handler,
|
||||
)
|
||||
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
self._client_transport: Optional[FakeTransport] = None
|
||||
self._server_transport: Optional[FakeTransport] = None
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
|
@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
def reconnect(self):
|
||||
def reconnect(self) -> None:
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
|
||||
|
@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = FakeTransport(self.client, self.reactor)
|
||||
self.server.makeConnection(self._server_transport)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
if self._client_transport:
|
||||
self._client_transport = None
|
||||
self.client.close()
|
||||
|
@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = None
|
||||
self.server.close()
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
|
@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
requests: List[SynapseRequest] = []
|
||||
real_request_factory = channel.requestFactory
|
||||
|
||||
def request_factory(*args, **kwargs):
|
||||
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
|
||||
request = real_request_factory(*args, **kwargs)
|
||||
requests.append(request)
|
||||
return request
|
||||
|
@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
):
|
||||
) -> None:
|
||||
"""Asserts that the given request is a HTTP replication request for
|
||||
fetching updates for given stream.
|
||||
"""
|
||||
|
@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
base["redis"] = {"enabled": True}
|
||||
return base
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
# build a replication server
|
||||
|
@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
lambda: self._handle_http_replication_attempt(self.hs, 8765),
|
||||
)
|
||||
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> ReplicationRestResource:
|
||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||
# We override this so that it automatically registers all the HTTP
|
||||
# replication servlets, without having to explicitly do that in all
|
||||
|
@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
return resource
|
||||
|
||||
def make_worker_hs(
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
|
||||
) -> HomeServer:
|
||||
"""Make a new worker HS instance, correctly connecting replcation
|
||||
stream to the master HS.
|
||||
|
@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def replicate(self):
|
||||
def replicate(self) -> None:
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump()
|
||||
|
||||
def _handle_http_replication_attempt(self, hs, repl_port):
|
||||
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
|
||||
"""Handles a connection attempt to the given HS replication HTTP
|
||||
listener on the given port.
|
||||
"""
|
||||
|
@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
# inside `connecTCP` before the connection has been passed back to the
|
||||
# code that requested the TCP connection.
|
||||
|
||||
def connect_any_redis_attempts(self):
|
||||
def connect_any_redis_attempts(self) -> None:
|
||||
"""If redis is enabled we need to deal with workers connecting to a
|
||||
redis server. We don't want to use a real Redis server so we use a
|
||||
fake one.
|
||||
|
@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
server_protocol = self._redis_server.buildProtocol(None)
|
||||
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
|
||||
client_protocol = client_factory.buildProtocol(client_address)
|
||||
|
||||
server_address = IPv4Address("TCP", host, port)
|
||||
server_protocol = self._redis_server.buildProtocol(server_address)
|
||||
|
||||
client_to_server_transport = FakeTransport(
|
||||
server_protocol, self.reactor, client_protocol
|
||||
|
@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
|
||||
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
) -> None:
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
class FakeRedisPubSubServer:
|
||||
"""A fake Redis server for pub/sub."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._subscribers_by_channel: Dict[
|
||||
bytes, Set["FakeRedisPubSubProtocol"]
|
||||
] = defaultdict(set)
|
||||
|
||||
def add_subscriber(self, conn, channel: bytes):
|
||||
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
|
||||
"""A connection has called SUBSCRIBE"""
|
||||
self._subscribers_by_channel[channel].add(conn)
|
||||
|
||||
def remove_subscriber(self, conn):
|
||||
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
|
||||
"""A connection has lost connection"""
|
||||
for subscribers in self._subscribers_by_channel.values():
|
||||
subscribers.discard(conn)
|
||||
|
||||
def publish(self, conn, channel: bytes, msg) -> int:
|
||||
def publish(
|
||||
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
|
||||
) -> int:
|
||||
"""A connection want to publish a message to subscribers."""
|
||||
for sub in self._subscribers_by_channel[channel]:
|
||||
sub.send(["message", channel, msg])
|
||||
|
||||
return len(self._subscribers_by_channel)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
|
||||
return FakeRedisPubSubProtocol(self)
|
||||
|
||||
|
||||
|
@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self._server = server
|
||||
self._reader = hiredis.Reader()
|
||||
|
||||
def dataReceived(self, data):
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
self._reader.feed(data)
|
||||
|
||||
# We might get multiple messages in one packet.
|
||||
|
@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
self.handle_command(msg[0], *msg[1:])
|
||||
|
||||
def handle_command(self, command, *args):
|
||||
def handle_command(self, command: bytes, *args: bytes) -> None:
|
||||
"""Received a Redis command from the client."""
|
||||
|
||||
# We currently only support pub/sub.
|
||||
|
@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self.send("PONG")
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown command: {command}")
|
||||
raise Exception(f"Unknown command: {command!r}")
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg: object) -> None:
|
||||
"""Send a message back to the client."""
|
||||
assert self.transport is not None
|
||||
|
||||
|
@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
self.transport.write(raw)
|
||||
self.transport.flush()
|
||||
|
||||
def encode(self, obj):
|
||||
def encode(self, obj: object) -> str:
|
||||
"""Encode an object to its Redis format.
|
||||
|
||||
Supports: strings/bytes, integers and list/tuples.
|
||||
|
@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
self._server.remove_subscriber(self)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue