Add a script to generate a clean config file (#4315)

This commit is contained in:
Richard van der Hoff 2018-12-21 16:04:57 +01:00 committed by Amber Brown
parent f3561f8d86
commit 9c2af7b2c5
11 changed files with 157 additions and 46 deletions

1
changelog.d/4315.feature Normal file
View File

@ -0,0 +1 @@
Add a script to generate a clean config file

67
scripts/generate_config Executable file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env python
import argparse
import sys
from synapse.config.homeserver import HomeServerConfig
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config-dir",
default="CONFDIR",
help="The path where the config files are kept. Used to create filenames for "
"things like the log config and the signing key. Default: %(default)s",
)
parser.add_argument(
"--data-dir",
default="DATADIR",
help="The path where the data files are kept. Used to create filenames for "
"things like the database and media store. Default: %(default)s",
)
parser.add_argument(
"--server-name",
default="SERVERNAME",
help="The server name. Used to initialise the server_name config param, but also "
"used in the names of some of the config files. Default: %(default)s",
)
parser.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"],
)
parser.add_argument(
"--generate-secrets",
action="store_true",
help="Enable generation of new secrets for things like the macaroon_secret_key."
"By default, these parameters will be left unset."
)
parser.add_argument(
"-o", "--output-file",
type=argparse.FileType('w'),
default=sys.stdout,
help="File to write the configuration to. Default: stdout",
)
args = parser.parse_args()
report_stats = args.report_stats
if report_stats is not None:
report_stats = report_stats == "yes"
conf = HomeServerConfig().generate_config(
config_dir_path=args.config_dir,
data_dir_path=args.data_dir,
server_name=args.server_name,
generate_secrets=args.generate_secrets,
report_stats=report_stats,
)
args.output_file.write(conf)

View File

@ -134,10 +134,6 @@ class Config(object):
with open(file_path) as file_stream: with open(file_path) as file_stream:
return file_stream.read() return file_stream.read()
@staticmethod
def default_path(name):
return os.path.abspath(os.path.join(os.path.curdir, name))
@staticmethod @staticmethod
def read_config_file(file_path): def read_config_file(file_path):
with open(file_path) as file_stream: with open(file_path) as file_stream:
@ -151,8 +147,39 @@ class Config(object):
return results return results
def generate_config( def generate_config(
self, config_dir_path, server_name, is_generating_file, report_stats=None self,
config_dir_path,
data_dir_path,
server_name,
generate_secrets=False,
report_stats=None,
): ):
"""Build a default configuration file
This is used both when the user explicitly asks us to generate a config file
(eg with --generate_config), and before loading the config at runtime (to give
a base which the config files override)
Args:
config_dir_path (str): The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
data_dir_path (str): The path where the data files are kept. Used to create
filenames for things like the database and media store.
server_name (str): The server name. Used to initialise the server_name
config param, but also used in the names of some of the config files.
generate_secrets (bool): True if we should generate new secrets for things
like the macaroon_secret_key. If False, these parameters will be left
unset.
report_stats (bool|None): Initial setting for the report_stats setting.
If None, report_stats will be left unset.
Returns:
str: the yaml config file
"""
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join( default_config += "\n\n".join(
@ -160,15 +187,14 @@ class Config(object):
for conf in self.invoke_all( for conf in self.invoke_all(
"default_config", "default_config",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=is_generating_file, generate_secrets=generate_secrets,
report_stats=report_stats, report_stats=report_stats,
) )
) )
config = yaml.load(default_config) return default_config
return default_config, config
@classmethod @classmethod
def load_config(cls, description, argv): def load_config(cls, description, argv):
@ -274,12 +300,14 @@ class Config(object):
if not cls.path_exists(config_dir_path): if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "w") as config_file: with open(config_path, "w") as config_file:
config_str, config = obj.generate_config( config_str = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True, generate_secrets=True,
) )
config = yaml.load(config_str)
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_str) config_file.write(config_str)
print( print(
@ -350,11 +378,13 @@ class Config(object):
raise ConfigError(MISSING_SERVER_NAME) raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = self.generate_config( config_string = self.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name, server_name=server_name,
is_generating_file=False, generate_secrets=False,
) )
config = yaml.load(config_string)
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from ._base import Config from ._base import Config
@ -45,8 +46,8 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path")) self.set_databasepath(config.get("database_path"))
def default_config(self, **kwargs): def default_config(self, data_dir_path, **kwargs):
database_path = self.abspath("homeserver.db") database_path = os.path.join(data_dir_path, "homeserver.db")
return """\ return """\
# Database configuration # Database configuration
database: database:

View File

@ -53,10 +53,3 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
ServerNoticesConfig, RoomDirectoryConfig, ServerNoticesConfig, RoomDirectoryConfig,
): ):
pass pass
if __name__ == '__main__':
import sys
sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
)

View File

@ -66,26 +66,35 @@ class KeyConfig(Config):
# falsification of values # falsification of values
self.form_secret = config.get("form_secret", None) self.form_secret = config.get("form_secret", None)
def default_config(self, config_dir_path, server_name, is_generating_file=False, def default_config(self, config_dir_path, server_name, generate_secrets=False,
**kwargs): **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file: if generate_secrets:
macaroon_secret_key = random_string_with_symbols(50) macaroon_secret_key = 'macaroon_secret_key: "%s"' % (
form_secret = '"%s"' % random_string_with_symbols(50) random_string_with_symbols(50),
)
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
else: else:
macaroon_secret_key = None macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>"
form_secret = 'null' form_secret = "# form_secret: <PRIVATE STRING>"
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s" # a secret which is used to sign access tokens. If none is specified,
# the registration_shared_secret is used, if one is given; otherwise,
# a secret key is derived from the signing key.
#
# Note that changing this will invalidate any active access tokens, so
# all clients will have to log back in.
%(macaroon_secret_key)s
# Used to enable access token expiration. # Used to enable access token expiration.
expire_access_token: False expire_access_token: False
# a secret which is used to calculate HMACs for form values, to stop # a secret which is used to calculate HMACs for form values, to stop
# falsification of values # falsification of values. Must be specified for the User Consent
form_secret: %(form_secret)s # forms to work.
%(form_secret)s
## Signing Keys ## ## Signing Keys ##

View File

@ -80,9 +80,7 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
log_config = self.abspath( log_config = os.path.join(config_dir_path, server_name + ".log.config")
os.path.join(config_dir_path, server_name + ".log.config")
)
return """ return """
# A yaml python logging config file # A yaml python logging config file
log_config: "%(log_config)s" log_config: "%(log_config)s"

View File

@ -24,10 +24,16 @@ class MetricsConfig(Config):
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, report_stats=None, **kwargs): def default_config(self, report_stats=None, **kwargs):
suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" res = """\
return ("""\
## Metrics ### ## Metrics ###
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
""" + suffix) % locals() """
if report_stats is None:
res += "# report_stats: true|false\n"
else:
res += "report_stats: %s\n" % ('true' if report_stats else 'false')
return res

View File

@ -50,8 +50,13 @@ class RegistrationConfig(Config):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
def default_config(self, **kwargs): def default_config(self, generate_secrets=False, **kwargs):
registration_shared_secret = random_string_with_symbols(50) if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
)
else:
registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
return """\ return """\
## Registration ## ## Registration ##
@ -78,7 +83,7 @@ class RegistrationConfig(Config):
# If set, allows registration by anyone who also has the shared # If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" %(registration_shared_secret)s
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from collections import namedtuple from collections import namedtuple
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
@ -175,9 +175,9 @@ class ContentRepositoryConfig(Config):
"url_preview_url_blacklist", () "url_preview_url_blacklist", ()
) )
def default_config(self, **kwargs): def default_config(self, data_dir_path, **kwargs):
media_store = self.default_path("media_store") media_store = os.path.join(data_dir_path, "media_store")
uploads_path = self.default_path("uploads") uploads_path = os.path.join(data_dir_path, "uploads")
return r""" return r"""
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os.path
from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.endpoint import parse_and_validate_server_name
@ -203,7 +204,7 @@ class ServerConfig(Config):
] ]
}) })
def default_config(self, server_name, **kwargs): def default_config(self, server_name, data_dir_path, **kwargs):
_, bind_port = parse_and_validate_server_name(server_name) _, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None: if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400
@ -211,7 +212,7 @@ class ServerConfig(Config):
bind_port = 8448 bind_port = 8448
unsecure_port = 8008 unsecure_port = 8008
pid_file = self.abspath("homeserver.pid") pid_file = os.path.join(data_dir_path, "homeserver.pid")
return """\ return """\
## Server ## ## Server ##