Don't require config to create database

This commit is contained in:
Erik Johnston 2016-04-06 14:08:18 +01:00
parent 2e308a3a38
commit 8aab9d87fa
13 changed files with 69 additions and 86 deletions

View File

@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse import argparse
import curses import curses
@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"],
} }
@ -292,7 +294,7 @@ class Porter(object):
} }
) )
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=None)
db_conn.commit() db_conn.commit()
@ -309,8 +311,8 @@ class Porter(object):
**self.postgres_config["args"] **self.postgres_config["args"]
) )
sqlite_engine = create_engine(FakeConfig(sqlite_config)) sqlite_engine = create_engine(sqlite_config)
postgres_engine = create_engine(FakeConfig(postgres_config)) postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,8 +794,3 @@ if __name__ == "__main__":
if end_error_exec_info: if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback) traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View File

@ -33,7 +33,7 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self): def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should # Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine. # not be passed to the database engine.
db_params = { db_params = {
@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer):
} }
db_conn = self.database_engine.module.connect(**db_params) db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn) if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn return db_conn
@ -386,7 +387,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
@ -402,8 +403,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = hs.get_db_conn() db_conn = hs.get_db_conn(run_new_connection=False)
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
db_conn.commit() db_conn.commit()

View File

@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
} }
def create_engine(config): def create_engine(database_config):
name = config.database_config["name"] name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None) engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class: if engine_class:
module = importlib.import_module(name) module = importlib.import_module(name)
return engine_class(module, config=config) return engine_class(module)
raise RuntimeError( raise RuntimeError(
"Unsupported database engine '%s'" % (name,) "Unsupported database engine '%s'" % (name,)

View File

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.prepare_database import prepare_database
from ._base import IncorrectDatabaseSetup from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False single_threaded = False
def __init__(self, database_module, config): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)
self.config = config
def check_database(self, txn): def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING") txn.execute("SHOW SERVER_ENCODING")
@ -44,9 +41,6 @@ class PostgresEngine(object):
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
) )
def prepare_database(self, db_conn):
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"] return error.pgcode in ["40001", "40P01"]

View File

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.prepare_database import ( from synapse.storage.prepare_database import prepare_database
prepare_database, prepare_sqlite3_database
)
import struct import struct
@ -23,9 +21,8 @@ import struct
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True single_threaded = True
def __init__(self, database_module, config): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.config = config
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -34,13 +31,9 @@ class Sqlite3Engine(object):
return sql return sql
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
self.prepare_database(db_conn) prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank) db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
return False return False

View File

@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config): def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables """Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version. or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
""" """
try: try:
cur = db_conn.cursor() cur = db_conn.cursor()
@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info: if version_info:
user_version, delta_files, upgraded = version_info user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine, config)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) if config is None:
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
raise UpgradeDatabaseException("Database needs to be upgraded")
else:
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine)
cur.close() cur.close()
db_conn.commit() db_conn.commit()
@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config):
raise raise
def _setup_new_database(cur, database_engine, config): def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then """Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas. applying any necessary deltas.
@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[], applied_delta_files=[],
upgraded=False, upgraded=False,
database_engine=database_engine, database_engine=database_engine,
config=config, config=None,
is_empty=True,
) )
def _upgrade_existing_database(cur, current_version, applied_delta_files, def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config): upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database. """Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file module_name, absolute_path, python_file
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine, config=config) module.run_create(cur, database_engine)
if not is_empty:
module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc": elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've # Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package # disabled their generation; e.g. from distribution package
@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View File

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs): def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex") cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall(): for row in cur.fetchall():
try: try:
@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?", "UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0]) (new_regex, row[0])
) )
def run_upgrade(*args, **kwargs):
pass

View File

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...") logger.info("Porting pushers table...")
cur.execute(""" cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 ( CREATE TABLE IF NOT EXISTS pushers2 (
@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers") cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count) logger.info("Moved %d pushers to new table", count)
def run_upgrade(*args, **kwargs):
pass

View File

@ -43,7 +43,7 @@ SQLITE_TABLE = (
) )
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()): for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement) cur.execute(statement)
@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql) sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json)) cur.execute(sql, ("event_search", progress_json))
def run_upgrade(*args, **kwargs):
pass

View File

@ -27,7 +27,7 @@ ALTER_TABLE = (
) )
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()): for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement) cur.execute(statement)
@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql) sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json)) cur.execute(sql, ("event_origin_server_ts", progress_json))
def run_upgrade(*args, **kwargs):
pass

View File

@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, config, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice. # NULL indicates user was not registered by an appservice.
try: try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so... # Maybe we already added the column? Hope so...
pass pass
def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users") cur.execute("SELECT name FROM users")
rows = cur.fetchall() rows = cur.fetchall()

View File

@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test", "test",
db_pool=self.db_pool, db_pool=self.db_pool,
config=config, config=config,
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
) )
self.datastore = SQLBaseStore(hs) self.datastore = SQLBaseStore(hs)

View File

@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer( hs = HomeServer(
name, db_pool=db_pool, config=config, name, db_pool=db_pool, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn, get_db_conn=db_pool.get_db_conn,
**kargs **kargs
) )
@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
**kargs **kargs
) )
@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn return conn
def create_engine(self): def create_engine(self):
return create_engine(self.config) return create_engine(self.config.database_config)
class MemoryDataStore(object): class MemoryDataStore(object):