diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index d675d8c8f..b63ecd4b5 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import read_schema +from synapse.storage import prepare_database from synapse.server import HomeServer @@ -36,7 +36,6 @@ from daemonize import Daemonize import twisted.manhole.telnet import logging -import sqlite3 import os import re import sys @@ -44,22 +43,6 @@ import sys logger = logging.getLogger(__name__) -SCHEMAS = [ - "transactions", - "pdu", - "users", - "profiles", - "presence", - "im", - "room_aliases", -] - - -# Remember to update this number every time an incompatible change is made to -# database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 3 - - class SynapseHomeServer(HomeServer): def build_http_client(self): @@ -80,52 +63,12 @@ class SynapseHomeServer(HomeServer): ) def build_db_pool(self): - """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we - don't have to worry about overwriting existing content. - """ - logging.info("Preparing database: %s...", self.db_name) - - with sqlite3.connect(self.db_name) as db_conn: - c = db_conn.cursor() - c.execute("PRAGMA user_version") - row = c.fetchone() - - if row and row[0]: - user_version = row[0] - - if user_version > SCHEMA_VERSION: - raise ValueError("Cannot use this database as it is too " + - "new for the server to understand" - ) - elif user_version < SCHEMA_VERSION: - logging.info("Upgrading database from version %d", - user_version - ) - - # Run every version since after the current version. - for v in range(user_version + 1, SCHEMA_VERSION + 1): - sql_script = read_schema("delta/v%d" % (v)) - c.executescript(sql_script) - - db_conn.commit() - - else: - for sql_loc in SCHEMAS: - sql_script = read_schema(sql_loc) - - c.executescript(sql_script) - db_conn.commit() - c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) - - c.close() - - logging.info("Database prepared in %s.", self.db_name) - - pool = adbapi.ConnectionPool( - 'sqlite3', self.db_name, check_same_thread=False, - cp_min=1, cp_max=1) - - return pool + return adbapi.ConnectionPool( + "sqlite3", self.get_db_name(), + check_same_thread=False, + cp_min=1, + cp_max=1 + ) def create_resource_tree(self, web_client, redirect_root_to_web_client): """Create the resource tree for this Home Server. @@ -270,6 +213,8 @@ def setup(): ) hs.start_listening(config.bind_port, config.unsecure_port) + prepare_database(hs.get_db_name()) + hs.get_db_pool() if config.manhole: diff --git a/synapse/server.py b/synapse/server.py index 83368ea5a..1ba13f3df 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -57,6 +57,7 @@ class BaseHomeServer(object): DEPENDENCIES = [ 'clock', 'http_client', + 'db_name', 'db_pool', 'persistence_service', 'replication_layer', diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ad2a484c1..2543fb12b 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -43,10 +43,28 @@ from .keys import KeyStore import json import logging import os +import sqlite3 logger = logging.getLogger(__name__) + +SCHEMAS = [ + "transactions", + "pdu", + "users", + "profiles", + "presence", + "im", + "room_aliases", +] + + +# Remember to update this number every time an incompatible change is made to +# database schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 3 + + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying something went wrong. @@ -350,3 +368,46 @@ def read_schema(schema): """ with open(schema_path(schema)) as schema_file: return schema_file.read() + + +def prepare_database(db_name): + """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we + don't have to worry about overwriting existing content. + """ + logging.info("Preparing database: %s...", db_name) + + with sqlite3.connect(db_name) as db_conn: + c = db_conn.cursor() + c.execute("PRAGMA user_version") + row = c.fetchone() + + if row and row[0]: + user_version = row[0] + + if user_version > SCHEMA_VERSION: + raise ValueError("Cannot use this database as it is too " + + "new for the server to understand" + ) + elif user_version < SCHEMA_VERSION: + logging.info("Upgrading database from version %d", + user_version + ) + + # Run every version since after the current version. + for v in range(user_version + 1, SCHEMA_VERSION + 1): + sql_script = read_schema("delta/v%d" % (v)) + c.executescript(sql_script) + + db_conn.commit() + + else: + for sql_loc in SCHEMAS: + sql_script = read_schema(sql_loc) + + c.executescript(sql_script) + db_conn.commit() + c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) + + c.close() + + logging.info("Database prepared in %s.", db_name)