Neater database setup at application startup time; only .connect() it once, not once per schema file; don't build the db_pool twice

This commit is contained in:
Paul "LeoNerd" Evans 2014-08-20 16:40:51 +01:00
parent a8774cf351
commit 648796ef1d

View File

@ -43,6 +43,17 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
@ -65,24 +76,11 @@ class SynapseHomeServer(HomeServer):
don't have to worry about overwriting existing content. don't have to worry about overwriting existing content.
""" """
logging.info("Preparing database: %s...", self.db_name) logging.info("Preparing database: %s...", self.db_name)
pool = adbapi.ConnectionPool(
'sqlite3', self.db_name, check_same_thread=False,
cp_min=1, cp_max=1)
schemas = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
for sql_loc in schemas:
sql_script = read_schema(sql_loc)
with sqlite3.connect(self.db_name) as db_conn: with sqlite3.connect(self.db_name) as db_conn:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c = db_conn.cursor() c = db_conn.cursor()
c.executescript(sql_script) c.executescript(sql_script)
c.close() c.close()
@ -90,6 +88,10 @@ class SynapseHomeServer(HomeServer):
logging.info("Database prepared in %s.", self.db_name) logging.info("Database prepared in %s.", self.db_name)
pool = adbapi.ConnectionPool(
'sqlite3', self.db_name, check_same_thread=False,
cp_min=1, cp_max=1)
return pool return pool
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
@ -282,7 +284,7 @@ def setup():
redirect_root_to_web_client=True) redirect_root_to_web_client=True)
hs.start_listening(args.port) hs.start_listening(args.port)
hs.build_db_pool() hs.get_db_pool()
if args.manhole: if args.manhole:
f = twisted.manhole.telnet.ShellFactory() f = twisted.manhole.telnet.ShellFactory()