mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Run database check before daemonizing, at the cost of database hygiene.
This commit is contained in:
parent
f8152f2708
commit
b02e1006b9
@ -18,7 +18,8 @@ import sys
|
|||||||
sys.dont_write_bytecode = True
|
sys.dont_write_bytecode = True
|
||||||
|
|
||||||
from synapse.storage import (
|
from synapse.storage import (
|
||||||
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
|
prepare_database, prepare_sqlite3_database, are_all_users_on_domain,
|
||||||
|
UpgradeDatabaseException,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -242,10 +243,9 @@ class SynapseHomeServer(HomeServer):
|
|||||||
)
|
)
|
||||||
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def run_startup_checks(self, db_conn):
|
||||||
def post_startup_check(self):
|
all_users_native = are_all_users_on_domain(
|
||||||
all_users_native = yield self.get_datastore().are_all_users_on_domain(
|
db_conn, self.hostname
|
||||||
self.hostname
|
|
||||||
)
|
)
|
||||||
if not all_users_native:
|
if not all_users_native:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
@ -254,9 +254,9 @@ class SynapseHomeServer(HomeServer):
|
|||||||
"Found users in database not native to %s!\n"
|
"Found users in database not native to %s!\n"
|
||||||
"You cannot changed a synapse server_name after it's been configured\n"
|
"You cannot changed a synapse server_name after it's been configured\n"
|
||||||
"******************************************************\n"
|
"******************************************************\n"
|
||||||
"\n"
|
"\n" % (self.hostname,)
|
||||||
)
|
)
|
||||||
reactor.stop()
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def get_version_string():
|
def get_version_string():
|
||||||
@ -392,6 +392,7 @@ def setup(config_options):
|
|||||||
with sqlite3.connect(db_name) as db_conn:
|
with sqlite3.connect(db_name) as db_conn:
|
||||||
prepare_sqlite3_database(db_conn)
|
prepare_sqlite3_database(db_conn)
|
||||||
prepare_database(db_conn)
|
prepare_database(db_conn)
|
||||||
|
hs.run_startup_checks(db_conn)
|
||||||
except UpgradeDatabaseException:
|
except UpgradeDatabaseException:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
"\nFailed to upgrade database.\n"
|
"\nFailed to upgrade database.\n"
|
||||||
@ -416,8 +417,6 @@ def setup(config_options):
|
|||||||
hs.get_datastore().start_profiling()
|
hs.get_datastore().start_profiling()
|
||||||
hs.get_replication_layer().start_get_pdu_cache()
|
hs.get_replication_layer().start_get_pdu_cache()
|
||||||
|
|
||||||
reactor.callWhenRunning(hs.post_startup_check)
|
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
@ -421,3 +421,13 @@ def prepare_sqlite3_database(db_conn):
|
|||||||
" VALUES (?,?)",
|
" VALUES (?,?)",
|
||||||
(row[0], False)
|
(row[0], False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def are_all_users_on_domain(txn, domain):
|
||||||
|
sql = "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
|
||||||
|
pat = "%:" + domain
|
||||||
|
cursor = txn.execute(sql, (pat,))
|
||||||
|
num_not_matching = cursor.fetchall()[0][0]
|
||||||
|
if num_not_matching == 0:
|
||||||
|
return True
|
||||||
|
return False
|
@ -144,21 +144,3 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
raise StoreError(404, "Token not found.")
|
raise StoreError(404, "Token not found.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def are_all_users_on_domain(self, domain):
|
|
||||||
res = yield self.runInteraction(
|
|
||||||
"are_all_users_on_domain",
|
|
||||||
self._are_all_users_on_domain_txn,
|
|
||||||
domain
|
|
||||||
)
|
|
||||||
defer.returnValue(res)
|
|
||||||
|
|
||||||
def _are_all_users_on_domain_txn(self, txn, domain):
|
|
||||||
sql = "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
|
|
||||||
pat = "%:" + domain
|
|
||||||
cursor = txn.execute(sql, (pat,))
|
|
||||||
num_not_matching = cursor.fetchall()[0][0]
|
|
||||||
if num_not_matching == 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
Loading…
Reference in New Issue
Block a user