Make database selection configurable

This commit is contained in:
Erik Johnston 2015-03-20 10:55:55 +00:00
parent 0d0610870d
commit 455579ca90
2 changed files with 42 additions and 11 deletions

View File

@ -61,6 +61,7 @@ import resource
import subprocess import subprocess
import sqlite3 import sqlite3
import syweb import syweb
import yaml
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -108,14 +109,14 @@ class SynapseHomeServer(HomeServer):
return None return None
def build_db_pool(self): def build_db_pool(self):
return adbapi.ConnectionPool( name = self.db_config.pop("name", None)
"sqlite3", self.get_db_name(), if name == "MySQLdb":
check_same_thread=False, return adbapi.ConnectionPool(
cp_min=1, name,
cp_max=1, **self.db_config
cp_openfun=prepare_database, # Prepare the database for each conn )
# so that :memory: sqlite works
) raise RuntimeError("Unsupported database type")
def create_resource_tree(self, redirect_root_to_web_client): def create_resource_tree(self, redirect_root_to_web_client):
"""Create the resource tree for this Home Server. """Create the resource tree for this Home Server.
@ -357,11 +358,29 @@ def setup(config_options):
tls_context_factory = context_factory.ServerContextFactory(config) tls_context_factory = context_factory.ServerContextFactory(config)
if config.database_config:
with open(config.database_config, 'r') as f:
db_config = yaml.safe_load(f)
name = db_config.get("name", None)
if name == "MySQLdb":
db_config.update({
"sql_mode": "TRADITIONAL",
"charset": "utf8",
"use_unicode": True,
})
else:
db_config = {
"name": "sqlite3",
"database": config.database_path,
}
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port, domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_name=config.database_path,
db_config=db_config,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
@ -377,9 +396,12 @@ def setup(config_options):
logger.info("Preparing database: %s...", db_name) logger.info("Preparing database: %s...", db_name)
try: try:
with sqlite3.connect(db_name) as db_conn: # with sqlite3.connect(db_name) as db_conn:
prepare_sqlite3_database(db_conn) # prepare_sqlite3_database(db_conn)
prepare_database(db_conn) # prepare_database(db_conn)
import MySQLdb
db_conn = MySQLdb.connect(**db_config)
prepare_database(db_conn)
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(
"\nFailed to upgrade database.\n" "\nFailed to upgrade database.\n"

View File

@ -26,6 +26,11 @@ class DatabaseConfig(Config):
self.database_path = self.abspath(args.database_path) self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size) self.event_cache_size = self.parse_size(args.event_cache_size)
if args.database_config:
self.database_config = self.abspath(args.database_config)
else:
self.database_config = None
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(DatabaseConfig, cls).add_arguments(parser) super(DatabaseConfig, cls).add_arguments(parser)
@ -38,6 +43,10 @@ class DatabaseConfig(Config):
"--event-cache-size", default="100K", "--event-cache-size", default="100K",
help="Number of events to cache in memory." help="Number of events to cache in memory."
) )
db_group.add_argument(
"--database-config", default=None,
help="Location of the database configuration file."
)
@classmethod @classmethod
def generate_config(cls, args, config_dir_path): def generate_config(cls, args, config_dir_path):