Add missing type hints to tests.config. (#14681)

This commit is contained in:
Patrick Cloke 2022-12-16 08:53:28 -05:00 committed by GitHub
parent 864c3f85b0
commit 3aeca2588b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 108 additions and 103 deletions

1
changelog.d/14681.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -36,8 +36,6 @@ exclude = (?x)
|tests/api/test_ratelimiting.py |tests/api/test_ratelimiting.py
|tests/app/test_openid_listener.py |tests/app/test_openid_listener.py
|tests/appservice/test_scheduler.py |tests/appservice/test_scheduler.py
|tests/config/test_cache.py
|tests/config/test_tls.py
|tests/crypto/test_keyring.py |tests/crypto/test_keyring.py
|tests/events/test_presence_router.py |tests/events/test_presence_router.py
|tests/events/test_utils.py |tests/events/test_utils.py
@ -89,7 +87,7 @@ disallow_untyped_defs = False
[mypy-tests.*] [mypy-tests.*]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.config.test_api] [mypy-tests.config.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.transport.test_client]

View File

@ -16,7 +16,7 @@ import logging
import os import os
import re import re
import threading import threading
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Mapping, Optional
import attr import attr
@ -94,7 +94,7 @@ def add_resizable_cache(
class CacheConfig(Config): class CacheConfig(Config):
section = "caches" section = "caches"
_environ = os.environ _environ: Mapping[str, str] = os.environ
event_cache_size: int event_cache_size: int
cache_factors: Dict[str, float] cache_factors: Dict[str, float]

View File

@ -788,26 +788,21 @@ class LruCache(Generic[KT, VT]):
def __contains__(self, key: KT) -> bool: def __contains__(self, key: KT) -> bool:
return self.contains(key) return self.contains(key)
def set_cache_factor(self, factor: float) -> bool: def set_cache_factor(self, factor: float) -> None:
""" """
Set the cache factor for this individual cache. Set the cache factor for this individual cache.
This will trigger a resize if it changes, which may require evicting This will trigger a resize if it changes, which may require evicting
items from the cache. items from the cache.
Returns:
Whether the cache changed size or not.
""" """
if not self.apply_cache_factor_from_config: if not self.apply_cache_factor_from_config:
return False return
new_size = int(self._original_max_size * factor) new_size = int(self._original_max_size * factor)
if new_size != self.max_size: if new_size != self.max_size:
self.max_size = new_size self.max_size = new_size
if self._on_resize: if self._on_resize:
self._on_resize() self._on_resize()
return True
return False
def __del__(self) -> None: def __del__(self) -> None:
# We're about to be deleted, so we make sure to clear up all the nodes # We're about to be deleted, so we make sure to clear up all the nodes

View File

@ -17,15 +17,15 @@ from tests.config.utils import ConfigFileTestCase
class ConfigMainFileTestCase(ConfigFileTestCase): class ConfigMainFileTestCase(ConfigFileTestCase):
def test_executes_without_an_action(self): def test_executes_without_an_action(self) -> None:
self.generate_config() self.generate_config()
main(["", "-c", self.config_file]) main(["", "-c", self.config_file])
def test_read__error_if_key_not_found(self): def test_read__error_if_key_not_found(self) -> None:
self.generate_config() self.generate_config()
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
main(["", "read", "foo.bar.hello", "-c", self.config_file]) main(["", "read", "foo.bar.hello", "-c", self.config_file])
def test_read__passes_if_key_found(self): def test_read__passes_if_key_found(self) -> None:
self.generate_config() self.generate_config()
main(["", "read", "server.server_name", "-c", self.config_file]) main(["", "read", "server.server_name", "-c", self.config_file])

View File

@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
# Tests that the default values in the config are correctly loaded. Note that the default # Tests that the default values in the config are correctly loaded. Note that the default
# values are loaded when the corresponding config options are commented out, which is why there isn't # values are loaded when the corresponding config options are commented out, which is why there isn't
# a config specified here. # a config specified here.
def test_default_configuration(self): def test_default_configuration(self) -> None:
background_updater = BackgroundUpdater( background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool self.hs, self.hs.get_datastores().main.db_pool
) )
@ -46,7 +46,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
""" """
) )
) )
def test_custom_configuration(self): def test_custom_configuration(self) -> None:
background_updater = BackgroundUpdater( background_updater = BackgroundUpdater(
self.hs, self.hs.get_datastores().main.db_pool self.hs, self.hs.get_datastores().main.db_pool
) )

View File

@ -24,13 +24,13 @@ from tests import unittest
class BaseConfigTestCase(unittest.TestCase): class BaseConfigTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
# The root object needs a server property with a public_baseurl. # The root object needs a server property with a public_baseurl.
root = Mock() root = Mock()
root.server.public_baseurl = "http://test" root.server.public_baseurl = "http://test"
self.config = Config(root) self.config = Config(root)
def test_loading_missing_templates(self): def test_loading_missing_templates(self) -> None:
# Use a temporary directory that exists on the system, but that isn't likely to # Use a temporary directory that exists on the system, but that isn't likely to
# contain template files # contain template files
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
@ -50,7 +50,7 @@ class BaseConfigTestCase(unittest.TestCase):
"Template file did not contain our test string", "Template file did not contain our test string",
) )
def test_loading_custom_templates(self): def test_loading_custom_templates(self) -> None:
# Use a temporary directory that exists on the system # Use a temporary directory that exists on the system
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
# Create a temporary bogus template file # Create a temporary bogus template file
@ -79,7 +79,7 @@ class BaseConfigTestCase(unittest.TestCase):
"Template file did not contain our test string", "Template file did not contain our test string",
) )
def test_multiple_custom_template_directories(self): def test_multiple_custom_template_directories(self) -> None:
"""Tests that directories are searched in the right order if multiple custom """Tests that directories are searched in the right order if multiple custom
template directories are provided. template directories are provided.
""" """
@ -137,7 +137,7 @@ class BaseConfigTestCase(unittest.TestCase):
for td in tempdirs: for td in tempdirs:
td.cleanup() td.cleanup()
def test_loading_template_from_nonexistent_custom_directory(self): def test_loading_template_from_nonexistent_custom_directory(self) -> None:
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
self.config.read_templates( self.config.read_templates(
["some_filename.html"], ("a_nonexistent_directory",) ["some_filename.html"], ("a_nonexistent_directory",)

View File

@ -13,26 +13,27 @@
# limitations under the License. # limitations under the License.
from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from tests.unittest import TestCase from tests.unittest import TestCase
class CacheConfigTests(TestCase): class CacheConfigTests(TestCase):
def setUp(self): def setUp(self) -> None:
# Reset caches before each test since there's global state involved. # Reset caches before each test since there's global state involved.
self.config = CacheConfig() self.config = CacheConfig()
self.config.reset() self.config.reset()
def tearDown(self): def tearDown(self) -> None:
# Also reset the caches after each test to leave state pristine. # Also reset the caches after each test to leave state pristine.
self.config.reset() self.config.reset()
def test_individual_caches_from_environ(self): def test_individual_caches_from_environ(self) -> None:
""" """
Individual cache factors will be loaded from the environment. Individual cache factors will be loaded from the environment.
""" """
config = {} config: JsonDict = {}
self.config._environ = { self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_NOT_CACHE": "BLAH", "SYNAPSE_NOT_CACHE": "BLAH",
@ -42,15 +43,15 @@ class CacheConfigTests(TestCase):
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
def test_config_overrides_environ(self): def test_config_overrides_environ(self) -> None:
""" """
Individual cache factors defined in the environment will take precedence Individual cache factors defined in the environment will take precedence
over those in the config. over those in the config.
""" """
config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
self.config._environ = { self.config._environ = {
"SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
"SYNAPSE_CACHE_FACTOR_FOO": 1, "SYNAPSE_CACHE_FACTOR_FOO": "1",
} }
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
@ -60,104 +61,104 @@ class CacheConfigTests(TestCase):
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
) )
def test_individual_instantiated_before_config_load(self): def test_individual_instantiated_before_config_load(self) -> None:
""" """
If a cache is instantiated before the config is read, it will be given If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized once the config the default cache size in the interim, and then resized once the config
is loaded. is loaded.
""" """
cache = LruCache(100) cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50) self.assertEqual(cache.max_size, 50)
config = {"caches": {"per_cache_factors": {"foo": 3}}} config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}}
self.config.read_config(config) self.config.read_config(config)
self.config.resize_all_caches() self.config.resize_all_caches()
self.assertEqual(cache.max_size, 300) self.assertEqual(cache.max_size, 300)
def test_individual_instantiated_after_config_load(self): def test_individual_instantiated_after_config_load(self) -> None:
""" """
If a cache is instantiated after the config is read, it will be If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the per_cache_factor if immediately resized to the correct size given the per_cache_factor if
there is one. there is one.
""" """
config = {"caches": {"per_cache_factors": {"foo": 2}}} config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}}
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
cache = LruCache(100) cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200) self.assertEqual(cache.max_size, 200)
def test_global_instantiated_before_config_load(self): def test_global_instantiated_before_config_load(self) -> None:
""" """
If a cache is instantiated before the config is read, it will be given If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized to the new the default cache size in the interim, and then resized to the new
default cache size once the config is loaded. default cache size once the config is loaded.
""" """
cache = LruCache(100) cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50) self.assertEqual(cache.max_size, 50)
config = {"caches": {"global_factor": 4}} config: JsonDict = {"caches": {"global_factor": 4}}
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
self.assertEqual(cache.max_size, 400) self.assertEqual(cache.max_size, 400)
def test_global_instantiated_after_config_load(self): def test_global_instantiated_after_config_load(self) -> None:
""" """
If a cache is instantiated after the config is read, it will be If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the global factor if there immediately resized to the correct size given the global factor if there
is no per-cache factor. is no per-cache factor.
""" """
config = {"caches": {"global_factor": 1.5}} config: JsonDict = {"caches": {"global_factor": 1.5}}
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
cache = LruCache(100) cache: LruCache = LruCache(100)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150) self.assertEqual(cache.max_size, 150)
def test_cache_with_asterisk_in_name(self): def test_cache_with_asterisk_in_name(self) -> None:
"""Some caches have asterisks in their name, test that they are set correctly.""" """Some caches have asterisks in their name, test that they are set correctly."""
config = { config: JsonDict = {
"caches": { "caches": {
"per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
} }
} }
self.config._environ = { self.config._environ = {
"SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
"SYNAPSE_CACHE_FACTOR_CACHE_B": 3, "SYNAPSE_CACHE_FACTOR_CACHE_B": "3",
} }
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
cache_a = LruCache(100) cache_a: LruCache = LruCache(100)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
self.assertEqual(cache_a.max_size, 200) self.assertEqual(cache_a.max_size, 200)
cache_b = LruCache(100) cache_b: LruCache = LruCache(100)
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
self.assertEqual(cache_b.max_size, 300) self.assertEqual(cache_b.max_size, 300)
cache_c = LruCache(100) cache_c: LruCache = LruCache(100)
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200) self.assertEqual(cache_c.max_size, 200)
def test_apply_cache_factor_from_config(self): def test_apply_cache_factor_from_config(self) -> None:
"""Caches can disable applying cache factor updates, mainly used by """Caches can disable applying cache factor updates, mainly used by
event cache size. event cache size.
""" """
config = {"caches": {"event_cache_size": "10k"}} config: JsonDict = {"caches": {"event_cache_size": "10k"}}
self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches() self.config.resize_all_caches()
cache = LruCache( cache: LruCache = LruCache(
max_size=self.config.event_cache_size, max_size=self.config.event_cache_size,
apply_cache_factor_from_config=False, apply_cache_factor_from_config=False,
) )

View File

@ -20,7 +20,7 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase): class DatabaseConfigTestCase(unittest.TestCase):
def test_database_configured_correctly(self): def test_database_configured_correctly(self) -> None:
conf = yaml.safe_load( conf = yaml.safe_load(
DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
) )

View File

@ -25,14 +25,14 @@ from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase): class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
self.file = os.path.join(self.dir, "homeserver.yaml") self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self): def tearDown(self) -> None:
shutil.rmtree(self.dir) shutil.rmtree(self.dir)
def test_generate_config_generates_files(self): def test_generate_config_generates_files(self) -> None:
with redirect_stdout(StringIO()): with redirect_stdout(StringIO()):
HomeServerConfig.load_or_generate_config( HomeServerConfig.load_or_generate_config(
"", "",
@ -56,7 +56,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
os.path.join(os.getcwd(), "homeserver.log"), os.path.join(os.getcwd(), "homeserver.log"),
) )
def assert_log_filename_is(self, log_config_file, expected): def assert_log_filename_is(self, log_config_file: str, expected: str) -> None:
with open(log_config_file) as f: with open(log_config_file) as f:
config = f.read() config = f.read()
# find the 'filename' line # find the 'filename' line

View File

@ -21,14 +21,14 @@ from tests.config.utils import ConfigFileTestCase
class ConfigLoadingFileTestCase(ConfigFileTestCase): class ConfigLoadingFileTestCase(ConfigFileTestCase):
def test_load_fails_if_server_name_missing(self): def test_load_fails_if_server_name_missing(self) -> None:
self.generate_config_and_remove_lines_containing("server_name") self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.config_file]) HomeServerConfig.load_config("", ["-c", self.config_file])
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
def test_generates_and_loads_macaroon_secret_key(self): def test_generates_and_loads_macaroon_secret_key(self) -> None:
self.generate_config() self.generate_config()
with open(self.config_file) as f: with open(self.config_file) as f:
@ -58,7 +58,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config2.key.macaroon_secret_key,) "was: %r" % (config2.key.macaroon_secret_key,)
) )
def test_load_succeeds_if_macaroon_secret_key_missing(self): def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None:
self.generate_config_and_remove_lines_containing("macaroon") self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
@ -73,7 +73,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1.key.macaroon_secret_key, config3.key.macaroon_secret_key config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
) )
def test_disable_registration(self): def test_disable_registration(self) -> None:
self.generate_config() self.generate_config()
self.add_lines_to_config( self.add_lines_to_config(
["enable_registration: true", "disable_registration: true"] ["enable_registration: true", "disable_registration: true"]
@ -93,7 +93,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
assert config3 is not None assert config3 is not None
self.assertTrue(config3.registration.enable_registration) self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self): def test_stats_enabled(self) -> None:
self.generate_config_and_remove_lines_containing("enable_metrics") self.generate_config_and_remove_lines_containing("enable_metrics")
self.add_lines_to_config(["enable_metrics: true"]) self.add_lines_to_config(["enable_metrics: true"])
@ -101,7 +101,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file]) config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.metrics.metrics_flags.known_servers) self.assertFalse(config.metrics.metrics_flags.known_servers)
def test_depreciated_identity_server_flag_throws_error(self): def test_depreciated_identity_server_flag_throws_error(self) -> None:
self.generate_config() self.generate_config()
# Needed to ensure that actual key/value pair added below don't end up on a line with a comment # Needed to ensure that actual key/value pair added below don't end up on a line with a comment
self.add_lines_to_config([" "]) self.add_lines_to_config([" "])

View File

@ -18,7 +18,7 @@ from tests.utils import default_config
class RatelimitConfigTestCase(TestCase): class RatelimitConfigTestCase(TestCase):
def test_parse_rc_federation(self): def test_parse_rc_federation(self) -> None:
config_dict = default_config("test") config_dict = default_config("test")
config_dict["rc_federation"] = { config_dict["rc_federation"] = {
"window_size": 20000, "window_size": 20000,

View File

@ -21,7 +21,7 @@ from tests.utils import default_config
class RegistrationConfigTestCase(ConfigFileTestCase): class RegistrationConfigTestCase(ConfigFileTestCase):
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None:
""" """
session_lifetime should logically be larger than, or at least as large as, session_lifetime should logically be larger than, or at least as large as,
all the different token lifetimes. all the different token lifetimes.
@ -91,7 +91,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
"", "",
) )
def test_refuse_to_start_if_open_registration_and_no_verification(self): def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None:
self.generate_config() self.generate_config()
self.add_lines_to_config( self.add_lines_to_config(
[ [

View File

@ -20,7 +20,7 @@ from tests import unittest
class RoomDirectoryConfigTestCase(unittest.TestCase): class RoomDirectoryConfigTestCase(unittest.TestCase):
def test_alias_creation_acl(self): def test_alias_creation_acl(self) -> None:
config = yaml.safe_load( config = yaml.safe_load(
""" """
alias_creation_rules: alias_creation_rules:
@ -78,7 +78,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
) )
) )
def test_room_publish_acl(self): def test_room_publish_acl(self) -> None:
config = yaml.safe_load( config = yaml.safe_load(
""" """
alias_creation_rules: [] alias_creation_rules: []

View File

@ -21,7 +21,7 @@ from tests import unittest
class ServerConfigTestCase(unittest.TestCase): class ServerConfigTestCase(unittest.TestCase):
def test_is_threepid_reserved(self): def test_is_threepid_reserved(self) -> None:
user1 = {"medium": "email", "address": "user1@example.com"} user1 = {"medium": "email", "address": "user1@example.com"}
user2 = {"medium": "email", "address": "user2@example.com"} user2 = {"medium": "email", "address": "user2@example.com"}
user3 = {"medium": "email", "address": "user3@example.com"} user3 = {"medium": "email", "address": "user3@example.com"}
@ -32,7 +32,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertFalse(is_threepid_reserved(config, user3)) self.assertFalse(is_threepid_reserved(config, user3))
self.assertFalse(is_threepid_reserved(config, user1_msisdn)) self.assertFalse(is_threepid_reserved(config, user1_msisdn))
def test_unsecure_listener_no_listeners_open_private_ports_false(self): def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None:
conf = yaml.safe_load( conf = yaml.safe_load(
ServerConfig().generate_config_section( ServerConfig().generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", False, None "CONFDIR", "/data_dir_path", "che.org", False, None
@ -52,7 +52,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], expected_listeners) self.assertEqual(conf["listeners"], expected_listeners)
def test_unsecure_listener_no_listeners_open_private_ports_true(self): def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None:
conf = yaml.safe_load( conf = yaml.safe_load(
ServerConfig().generate_config_section( ServerConfig().generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", True, None "CONFDIR", "/data_dir_path", "che.org", True, None
@ -71,7 +71,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], expected_listeners) self.assertEqual(conf["listeners"], expected_listeners)
def test_listeners_set_correctly_open_private_ports_false(self): def test_listeners_set_correctly_open_private_ports_false(self) -> None:
listeners = [ listeners = [
{ {
"port": 8448, "port": 8448,
@ -95,7 +95,7 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertEqual(conf["listeners"], listeners) self.assertEqual(conf["listeners"], listeners)
def test_listeners_set_correctly_open_private_ports_true(self): def test_listeners_set_correctly_open_private_ports_true(self) -> None:
listeners = [ listeners = [
{ {
"port": 8448, "port": 8448,
@ -131,14 +131,14 @@ class ServerConfigTestCase(unittest.TestCase):
class GenerateIpSetTestCase(unittest.TestCase): class GenerateIpSetTestCase(unittest.TestCase):
def test_empty(self): def test_empty(self) -> None:
ip_set = generate_ip_set(()) ip_set = generate_ip_set(())
self.assertFalse(ip_set) self.assertFalse(ip_set)
ip_set = generate_ip_set((), ()) ip_set = generate_ip_set((), ())
self.assertFalse(ip_set) self.assertFalse(ip_set)
def test_generate(self): def test_generate(self) -> None:
"""Check adding IPv4 and IPv6 addresses.""" """Check adding IPv4 and IPv6 addresses."""
# IPv4 address # IPv4 address
ip_set = generate_ip_set(("1.2.3.4",)) ip_set = generate_ip_set(("1.2.3.4",))
@ -160,7 +160,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4")) ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
self.assertEqual(len(ip_set.iter_cidrs()), 4) self.assertEqual(len(ip_set.iter_cidrs()), 4)
def test_extra(self): def test_extra(self) -> None:
"""Extra IP addresses are treated the same.""" """Extra IP addresses are treated the same."""
ip_set = generate_ip_set((), ("1.2.3.4",)) ip_set = generate_ip_set((), ("1.2.3.4",))
self.assertEqual(len(ip_set.iter_cidrs()), 4) self.assertEqual(len(ip_set.iter_cidrs()), 4)
@ -172,7 +172,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",)) ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
self.assertEqual(len(ip_set.iter_cidrs()), 4) self.assertEqual(len(ip_set.iter_cidrs()), 4)
def test_bad_value(self): def test_bad_value(self) -> None:
"""An error should be raised if a bad value is passed in.""" """An error should be raised if a bad value is passed in."""
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
generate_ip_set(("not-an-ip",)) generate_ip_set(("not-an-ip",))

View File

@ -13,13 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import cast
import idna import idna
from OpenSSL import SSL from OpenSSL import SSL
from synapse.config._base import Config, RootConfig from synapse.config._base import Config, RootConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.tls import ConfigError, TlsConfig from synapse.config.tls import ConfigError, TlsConfig
from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.crypto.context_factory import (
FederationPolicyForHTTPS,
SSLClientConnectionCreator,
)
from synapse.types import JsonDict
from tests.unittest import TestCase from tests.unittest import TestCase
@ -27,7 +34,7 @@ from tests.unittest import TestCase
class FakeServer(Config): class FakeServer(Config):
section = "server" section = "server"
def has_tls_listener(self): def has_tls_listener(self) -> bool:
return False return False
@ -36,21 +43,21 @@ class TestConfig(RootConfig):
class TLSConfigTests(TestCase): class TLSConfigTests(TestCase):
def test_tls_client_minimum_default(self): def test_tls_client_minimum_default(self) -> None:
""" """
The default client TLS version is 1.0. The default client TLS version is 1.0.
""" """
config = {} config: JsonDict = {}
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
def test_tls_client_minimum_set(self): def test_tls_client_minimum_set(self) -> None:
""" """
The default client TLS version can be set to 1.0, 1.1, and 1.2. The default client TLS version can be set to 1.0, 1.1, and 1.2.
""" """
config = {"federation_client_minimum_tls_version": 1} config: JsonDict = {"federation_client_minimum_tls_version": 1}
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
@ -76,7 +83,7 @@ class TLSConfigTests(TestCase):
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
def test_tls_client_minimum_1_point_3_missing(self): def test_tls_client_minimum_1_point_3_missing(self) -> None:
""" """
If TLS 1.3 support is missing and it's configured, it will raise a If TLS 1.3 support is missing and it's configured, it will raise a
ConfigError. ConfigError.
@ -88,7 +95,7 @@ class TLSConfigTests(TestCase):
self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3) self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
assert not hasattr(SSL, "OP_NO_TLSv1_3") assert not hasattr(SSL, "OP_NO_TLSv1_3")
config = {"federation_client_minimum_tls_version": 1.3} config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig() t = TestConfig()
with self.assertRaises(ConfigError) as e: with self.assertRaises(ConfigError) as e:
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
@ -100,7 +107,7 @@ class TLSConfigTests(TestCase):
), ),
) )
def test_tls_client_minimum_1_point_3_exists(self): def test_tls_client_minimum_1_point_3_exists(self) -> None:
""" """
If TLS 1.3 support exists and it's configured, it will be settable. If TLS 1.3 support exists and it's configured, it will be settable.
""" """
@ -110,20 +117,20 @@ class TLSConfigTests(TestCase):
self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3")) self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
assert hasattr(SSL, "OP_NO_TLSv1_3") assert hasattr(SSL, "OP_NO_TLSv1_3")
config = {"federation_client_minimum_tls_version": 1.3} config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
def test_tls_client_minimum_set_passed_through_1_2(self): def test_tls_client_minimum_set_passed_through_1_2(self) -> None:
""" """
The configured TLS version is correctly configured by the ContextFactory. The configured TLS version is correctly configured by the ContextFactory.
""" """
config = {"federation_client_minimum_tls_version": 1.2} config: JsonDict = {"federation_client_minimum_tls_version": 1.2}
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t) cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
options = _get_ssl_context_options(cf._verify_ssl_context) options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
@ -131,15 +138,15 @@ class TLSConfigTests(TestCase):
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_tls_client_minimum_set_passed_through_1_0(self): def test_tls_client_minimum_set_passed_through_1_0(self) -> None:
""" """
The configured TLS version is correctly configured by the ContextFactory. The configured TLS version is correctly configured by the ContextFactory.
""" """
config = {"federation_client_minimum_tls_version": 1} config: JsonDict = {"federation_client_minimum_tls_version": 1}
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t) cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
options = _get_ssl_context_options(cf._verify_ssl_context) options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has not had any of the NO_TLS set. # The context has not had any of the NO_TLS set.
@ -147,11 +154,11 @@ class TLSConfigTests(TestCase):
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_whitelist_idna_failure(self): def test_whitelist_idna_failure(self) -> None:
""" """
The federation certificate whitelist will not allow IDNA domain names. The federation certificate whitelist will not allow IDNA domain names.
""" """
config = { config: JsonDict = {
"federation_certificate_verification_whitelist": [ "federation_certificate_verification_whitelist": [
"example.com", "example.com",
"*.ドメイン.テスト", "*.ドメイン.テスト",
@ -163,11 +170,11 @@ class TLSConfigTests(TestCase):
) )
self.assertIn("IDNA domain names", str(e)) self.assertIn("IDNA domain names", str(e))
def test_whitelist_idna_result(self): def test_whitelist_idna_result(self) -> None:
""" """
The federation certificate whitelist will match on IDNA encoded names. The federation certificate whitelist will match on IDNA encoded names.
""" """
config = { config: JsonDict = {
"federation_certificate_verification_whitelist": [ "federation_certificate_verification_whitelist": [
"example.com", "example.com",
"*.xn--eckwd4c7c.xn--zckzah", "*.xn--eckwd4c7c.xn--zckzah",
@ -176,14 +183,16 @@ class TLSConfigTests(TestCase):
t = TestConfig() t = TestConfig()
t.tls.read_config(config, config_dir_path="", data_dir_path="") t.tls.read_config(config, config_dir_path="", data_dir_path="")
cf = FederationPolicyForHTTPS(t) cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
# Not in the whitelist # Not in the whitelist
opts = cf.get_options(b"notexample.com") opts = cf.get_options(b"notexample.com")
assert isinstance(opts, SSLClientConnectionCreator)
self.assertTrue(opts._verifier._verify_certs) self.assertTrue(opts._verifier._verify_certs)
# Caught by the wildcard # Caught by the wildcard
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
assert isinstance(opts, SSLClientConnectionCreator)
self.assertFalse(opts._verifier._verify_certs) self.assertFalse(opts._verifier._verify_certs)
@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
"""get the options bits from an openssl context object""" """get the options bits from an openssl context object"""
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
# use the low-level interface # use the low-level interface
return SSL._lib.SSL_CTX_get_options(ssl_context._context) return SSL._lib.SSL_CTX_get_options(ssl_context._context) # type: ignore[attr-defined]

View File

@ -21,7 +21,7 @@ from tests.unittest import TestCase
class ValidateConfigTestCase(TestCase): class ValidateConfigTestCase(TestCase):
"""Test cases for synapse.config._util.validate_config""" """Test cases for synapse.config._util.validate_config"""
def test_bad_object_in_array(self): def test_bad_object_in_array(self) -> None:
"""malformed objects within an array should be validated correctly""" """malformed objects within an array should be validated correctly"""
# consider a structure: # consider a structure:

View File

@ -17,19 +17,20 @@ import tempfile
import unittest import unittest
from contextlib import redirect_stdout from contextlib import redirect_stdout
from io import StringIO from io import StringIO
from typing import List
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
class ConfigFileTestCase(unittest.TestCase): class ConfigFileTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
self.config_file = os.path.join(self.dir, "homeserver.yaml") self.config_file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self): def tearDown(self) -> None:
shutil.rmtree(self.dir) shutil.rmtree(self.dir)
def generate_config(self): def generate_config(self) -> None:
with redirect_stdout(StringIO()): with redirect_stdout(StringIO()):
HomeServerConfig.load_or_generate_config( HomeServerConfig.load_or_generate_config(
"", "",
@ -43,7 +44,7 @@ class ConfigFileTestCase(unittest.TestCase):
], ],
) )
def generate_config_and_remove_lines_containing(self, needle): def generate_config_and_remove_lines_containing(self, needle: str) -> None:
self.generate_config() self.generate_config()
with open(self.config_file) as f: with open(self.config_file) as f:
@ -52,7 +53,7 @@ class ConfigFileTestCase(unittest.TestCase):
with open(self.config_file, "w") as f: with open(self.config_file, "w") as f:
f.write("".join(contents)) f.write("".join(contents))
def add_lines_to_config(self, lines): def add_lines_to_config(self, lines: List[str]) -> None:
with open(self.config_file, "a") as f: with open(self.config_file, "a") as f:
for line in lines: for line in lines:
f.write(line + "\n") f.write(line + "\n")