Add missing type hints to tests.replication. (#14987)

This commit is contained in:
Patrick Cloke 2023-02-06 09:55:00 -05:00 committed by GitHub
parent b275763c65
commit 156cd88eef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 193 additions and 149 deletions

View file

@ -18,12 +18,14 @@ from typing import Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
def test_basic(self):
def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
def test_download_simple_file_race(self):
def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self):
def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
def get_connection_factory():
def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
@ -263,6 +265,6 @@ def _build_test_server(
return server_tls_factory.buildProtocol(None)
def _log_request(request):
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)