Merge pull request #5523 from matrix-org/rav/arg_defaults

Stop conflating generated config and default config
This commit is contained in:
Richard van der Hoff 2019-06-24 17:24:35 +01:00 committed by GitHub
commit af8a962905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 88 additions and 107 deletions

1
changelog.d/5523.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a regression where homeservers on private IP addresses were incorrectly blacklisted.

View File

@ -136,11 +136,6 @@ class Config(object):
with open(file_path) as file_stream: with open(file_path) as file_stream:
return file_stream.read() return file_stream.read()
@staticmethod
def read_config_file(file_path):
with open(file_path) as file_stream:
return yaml.safe_load(file_stream)
def invoke_all(self, name, *args, **kargs): def invoke_all(self, name, *args, **kargs):
results = [] results = []
for cls in type(self).mro(): for cls in type(self).mro():
@ -158,9 +153,8 @@ class Config(object):
): ):
"""Build a default configuration file """Build a default configuration file
This is used both when the user explicitly asks us to generate a config file This is used when the user explicitly asks us to generate a config file
(eg with --generate_config), and before loading the config at runtime (to give (eg with --generate_config).
a base which the config files override)
Args: Args:
config_dir_path (str): The path where the config files are kept. Used to config_dir_path (str): The path where the config files are kept. Used to
@ -182,10 +176,10 @@ class Config(object):
Returns: Returns:
str: the yaml config file str: the yaml config file
""" """
default_config = "\n\n".join( return "\n\n".join(
dedent(conf) dedent(conf)
for conf in self.invoke_all( for conf in self.invoke_all(
"default_config", "generate_config_section",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
data_dir_path=data_dir_path, data_dir_path=data_dir_path,
server_name=server_name, server_name=server_name,
@ -194,8 +188,6 @@ class Config(object):
) )
) )
return default_config
@classmethod @classmethod
def load_config(cls, description, argv): def load_config(cls, description, argv):
"""Parse the commandline and config files """Parse the commandline and config files
@ -240,9 +232,7 @@ class Config(object):
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
data_dir_path = os.getcwd() data_dir_path = os.getcwd()
config_dict = obj.read_config_files( config_dict = read_config_files(config_files)
config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
obj.parse_config_dict( obj.parse_config_dict(
config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
) )
@ -354,8 +344,8 @@ class Config(object):
config_file.write("# vim:ft=yaml\n\n") config_file.write("# vim:ft=yaml\n\n")
config_file.write(config_str) config_file.write(config_str)
config = yaml.safe_load(config_str) config_dict = yaml.safe_load(config_str)
obj.invoke_all("generate_files", config) obj.generate_missing_files(config_dict, config_dir_path)
print( print(
( (
@ -385,12 +375,9 @@ class Config(object):
obj.invoke_all("add_arguments", parser) obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args) args = parser.parse_args(remaining_args)
config_dict = obj.read_config_files( config_dict = read_config_files(config_files)
config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
if generate_missing_configs: if generate_missing_configs:
obj.generate_missing_files(config_dict) obj.generate_missing_files(config_dict, config_dir_path)
return None return None
obj.parse_config_dict( obj.parse_config_dict(
@ -400,53 +387,6 @@ class Config(object):
return obj return obj
def read_config_files(self, config_files, config_dir_path, data_dir_path):
"""Read the config files into a dict
Args:
config_files (iterable[str]): A list of the config files to read
config_dir_path (str): The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
data_dir_path (str): The path where the data files are kept. Used to create
filenames for things like the database and media store.
Returns: dict
"""
# first we read the config files into a dict
specified_config = {}
for config_file in config_files:
yaml_config = self.read_config_file(config_file)
specified_config.update(yaml_config)
# not all of the options have sensible defaults in code, so we now need to
# generate a default config file suitable for the specified server name...
if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
config_string = self.generate_config(
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name,
generate_secrets=False,
)
# ... and read it into a base config dict ...
config = yaml.safe_load(config_string)
# ... and finally, overlay it with the actual configuration.
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
+ "\n"
+ MISSING_REPORT_STATS_SPIEL
)
return config
def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
"""Read the information from the config dict into this Config object. """Read the information from the config dict into this Config object.
@ -466,8 +406,32 @@ class Config(object):
data_dir_path=data_dir_path, data_dir_path=data_dir_path,
) )
def generate_missing_files(self, config_dict): def generate_missing_files(self, config_dict, config_dir_path):
self.invoke_all("generate_files", config_dict) self.invoke_all("generate_files", config_dict, config_dir_path)
def read_config_files(config_files):
"""Read the config files into a dict
Args:
config_files (iterable[str]): A list of the config files to read
Returns: dict
"""
specified_config = {}
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
specified_config.update(yaml_config)
if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME)
if "report_stats" not in specified_config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
)
return specified_config
def find_config_files(search_paths): def find_config_files(search_paths):

View File

@ -30,7 +30,7 @@ class ApiConfig(Config):
], ],
) )
def default_config(cls, **kwargs): def generate_config_section(cls, **kwargs):
return """\ return """\
## API Configuration ## ## API Configuration ##

View File

@ -34,7 +34,7 @@ class AppServiceConfig(Config):
self.notify_appservices = config.get("notify_appservices", True) self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def default_config(cls, **kwargs): def generate_config_section(cls, **kwargs):
return """\ return """\
# A list of application service config files to use # A list of application service config files to use
# #

View File

@ -28,7 +28,7 @@ class CaptchaConfig(Config):
"https://www.recaptcha.net/recaptcha/api/siteverify", "https://www.recaptcha.net/recaptcha/api/siteverify",
) )
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
## Captcha ## ## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this. # See docs/CAPTCHA_SETUP for full details of configuring this.

View File

@ -35,7 +35,7 @@ class CasConfig(Config):
self.cas_service_url = None self.cas_service_url = None
self.cas_required_attributes = {} self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable CAS for registration and login. # Enable CAS for registration and login.
# #

View File

@ -111,5 +111,5 @@ class ConsentConfig(Config):
"policy_name", "Privacy Policy" "policy_name", "Privacy Policy"
) )
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG return DEFAULT_CONFIG

View File

@ -38,7 +38,7 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path")) self.set_databasepath(config.get("database_path"))
def default_config(self, data_dir_path, **kwargs): def generate_config_section(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db") database_path = os.path.join(data_dir_path, "homeserver.db")
return ( return (
"""\ """\

View File

@ -214,7 +214,7 @@ class EmailConfig(Config):
if not os.path.isfile(p): if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p,)) raise ConfigError("Unable to find email template file %s" % (p,))
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable sending emails for password resets, notification events or # Enable sending emails for password resets, notification events or
# account expiry notices # account expiry notices

View File

@ -21,7 +21,7 @@ class GroupsConfig(Config):
self.enable_group_creation = config.get("enable_group_creation", False) self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "") self.group_creation_prefix = config.get("group_creation_prefix", "")
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
# Uncomment to allow non-server-admin users to create groups on this server # Uncomment to allow non-server-admin users to create groups on this server
# #

View File

@ -41,7 +41,7 @@ class JWTConfig(Config):
self.jwt_secret = None self.jwt_secret = None
self.jwt_algorithm = None self.jwt_algorithm = None
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
# The JWT needs to contain a globally unique "sub" (subject) claim. # The JWT needs to contain a globally unique "sub" (subject) claim.
# #

View File

@ -65,13 +65,18 @@ class TrustedKeyServer(object):
class KeyConfig(Config): class KeyConfig(Config):
def read_config(self, config, **kwargs): def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file # the signing key can be specified inline or in a separate file
if "signing_key" in config: if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]]) self.signing_key = read_signing_keys([config["signing_key"]])
else: else:
self.signing_key_path = config["signing_key_path"] signing_key_path = config.get("signing_key_path")
self.signing_key = self.read_signing_key(self.signing_key_path) if signing_key_path is None:
signing_key_path = os.path.join(
config_dir_path, config["server_name"] + ".signing.key"
)
self.signing_key = self.read_signing_key(signing_key_path)
self.old_signing_keys = self.read_old_signing_keys( self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {}) config.get("old_signing_keys", {})
@ -117,7 +122,7 @@ class KeyConfig(Config):
# falsification of values # falsification of values
self.form_secret = config.get("form_secret", None) self.form_secret = config.get("form_secret", None)
def default_config( def generate_config_section(
self, config_dir_path, server_name, generate_secrets=False, **kwargs self, config_dir_path, server_name, generate_secrets=False, **kwargs
): ):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
@ -237,8 +242,15 @@ class KeyConfig(Config):
) )
return keys return keys
def generate_files(self, config): def generate_files(self, config, config_dir_path):
signing_key_path = config["signing_key_path"] if "signing_key" in config:
return
signing_key_path = config.get("signing_key_path")
if signing_key_path is None:
signing_key_path = os.path.join(
config_dir_path, config["server_name"] + ".signing.key"
)
if not self.path_exists(signing_key_path): if not self.path_exists(signing_key_path):
print("Generating signing key file %s" % (signing_key_path,)) print("Generating signing key file %s" % (signing_key_path,))

View File

@ -80,7 +80,7 @@ class LoggingConfig(Config):
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config") log_config = os.path.join(config_dir_path, server_name + ".log.config")
return ( return (
"""\ """\
@ -133,7 +133,7 @@ class LoggingConfig(Config):
help="Do not redirect stdout/stderr to the log", help="Do not redirect stdout/stderr to the log",
) )
def generate_files(self, config): def generate_files(self, config, config_dir_path):
log_config = config.get("log_config") log_config = config.get("log_config")
if log_config and not os.path.exists(log_config): if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")

View File

@ -40,7 +40,7 @@ class MetricsConfig(Config):
"sentry.dsn field is required when sentry integration is enabled" "sentry.dsn field is required when sentry integration is enabled"
) )
def default_config(self, report_stats=None, **kwargs): def generate_config_section(self, report_stats=None, **kwargs):
res = """\ res = """\
## Metrics ### ## Metrics ###

View File

@ -28,7 +28,7 @@ class PasswordConfig(Config):
self.password_enabled = password_config.get("enabled", True) self.password_enabled = password_config.get("enabled", True)
self.password_pepper = password_config.get("pepper", "") self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
password_config: password_config:
# Uncomment to disable password login # Uncomment to disable password login

View File

@ -46,7 +46,7 @@ class PasswordAuthProviderConfig(Config):
self.password_providers.append((provider_class, provider_config)) self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
#password_providers: #password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider" # - module: "ldap_auth_provider.LdapAuthProvider"

View File

@ -42,7 +42,7 @@ class PushConfig(Config):
) )
self.push_include_content = not redact_content self.push_include_content = not redact_content
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# Clients requesting push notifications can either have the body of # Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details # the message sent in the notification poke along with other details

View File

@ -80,7 +80,7 @@ class RatelimitConfig(Config):
"federation_rr_transactions_per_room_per_second", 50 "federation_rr_transactions_per_room_per_second", 50
) )
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##

View File

@ -85,7 +85,7 @@ class RegistrationConfig(Config):
"disable_msisdn_registration", False "disable_msisdn_registration", False
) )
def default_config(self, generate_secrets=False, **kwargs): def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets: if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % ( registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50), random_string_with_symbols(50),

View File

@ -91,7 +91,9 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
backup_media_store_path = config.get("backup_media_store_path") backup_media_store_path = config.get("backup_media_store_path")
@ -148,7 +150,7 @@ class ContentRepositoryConfig(Config):
(provider_class, parsed_config, wrapper_config) (provider_class, parsed_config, wrapper_config)
) )
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads"))
self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False)
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES)
@ -188,7 +190,7 @@ class ContentRepositoryConfig(Config):
self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
def default_config(self, data_dir_path, **kwargs): def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store") media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads") uploads_path = os.path.join(data_dir_path, "uploads")

View File

@ -46,7 +46,7 @@ class RoomDirectoryConfig(Config):
_RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
] ]
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# Uncomment to disable searching the public room list. When disabled # Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote # blocks searching local and remote room lists for local and remote

View File

@ -61,7 +61,7 @@ class SAML2Config(Config):
}, },
} }
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
# Enable SAML2 for registration and login. Uses pysaml2. # Enable SAML2 for registration and login. Uses pysaml2.
# #

View File

@ -327,7 +327,7 @@ class ServerConfig(Config):
def has_tls_listener(self): def has_tls_listener(self):
return any(l["tls"] for l in self.listeners) return any(l["tls"] for l in self.listeners)
def default_config(self, server_name, data_dir_path, **kwargs): def generate_config_section(self, server_name, data_dir_path, **kwargs):
_, bind_port = parse_and_validate_server_name(server_name) _, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None: if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400

View File

@ -78,5 +78,5 @@ class ServerNoticesConfig(Config):
# todo: i18n # todo: i18n
self.server_notices_room_name = c.get("room_name", "Server Notices") self.server_notices_room_name = c.get("room_name", "Server Notices")
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG return DEFAULT_CONFIG

View File

@ -26,7 +26,7 @@ class SpamCheckerConfig(Config):
if provider is not None: if provider is not None:
self.spam_checker = load_module(provider) self.spam_checker = load_module(provider)
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
#spam_checker: #spam_checker:
# module: "my_custom_project.SuperSpamChecker" # module: "my_custom_project.SuperSpamChecker"

View File

@ -42,7 +42,7 @@ class StatsConfig(Config):
/ 1000 / 1000
) )
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# Local statistics collection. Used in populating the room directory. # Local statistics collection. Used in populating the room directory.
# #

View File

@ -26,7 +26,7 @@ class ThirdPartyRulesConfig(Config):
if provider is not None: if provider is not None:
self.third_party_event_rules = load_module(provider) self.third_party_event_rules = load_module(provider)
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
# Server admins can define a Python module that implements extra rules for # Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to # allowing or denying incoming events. In order to work, this module needs to

View File

@ -217,7 +217,9 @@ class TlsConfig(Config):
if sha256_fingerprint not in sha256_fingerprints: if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({"sha256": sha256_fingerprint}) self.tls_fingerprints.append({"sha256": sha256_fingerprint})
def default_config(self, config_dir_path, server_name, data_dir_path, **kwargs): def generate_config_section(
self, config_dir_path, server_name, data_dir_path, **kwargs
):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt" tls_certificate_path = base_key_name + ".tls.crt"

View File

@ -33,7 +33,7 @@ class UserDirectoryConfig(Config):
"search_all_users", False "search_all_users", False
) )
def default_config(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# User Directory configuration # User Directory configuration
# #

View File

@ -26,7 +26,7 @@ class VoipConfig(Config):
) )
self.turn_allow_guests = config.get("turn_allow_guests", True) self.turn_allow_guests = config.get("turn_allow_guests", True)
def default_config(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
## TURN ## ## TURN ##