Port synapse_port_db to async/await (#6718)

* Raise an exception if there are pending background updates

So that the script exits with a non-zero return code

* Changelog

* Port synapse_port_db to async/await

* Port update_database to async/await

* Add version string to mocked homeservers

* Remove unused imports

* Convert overlooked bits to async/await

* Fixup logging contexts

* Fix imports

* Add a way to print an error without raising an exception

* Incorporate review
This commit is contained in:
Brendan Abolivier 2020-01-21 19:04:58 +00:00 committed by GitHub
parent 0e68760078
commit 07124d028d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 126 additions and 89 deletions

1
changelog.d/6718.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug causing the `synapse_port_db` script to return 0 in a specific error case.

View File

@ -22,10 +22,12 @@ import yaml
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
import synapse
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("update_database") logger = logging.getLogger("update_database")
@ -38,6 +40,8 @@ class MockHomeserver(HomeServer):
config.server_name, reactor=reactor, config=config, **kwargs config.server_name, reactor=reactor, config=config, **kwargs
) )
self.version_string = "Synapse/"+get_version_string(synapse)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -81,15 +85,17 @@ if __name__ == "__main__":
hs.setup() hs.setup()
store = hs.get_datastore() store = hs.get_datastore()
@defer.inlineCallbacks async def run_background_updates():
def run_background_updates(): await store.db.updates.run_background_updates(sleep=False)
yield store.db.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run. # Stop the reactor to exit the script once every background update is run.
reactor.stop() reactor.stop()
# Apply all background updates on the database. def run():
reactor.callWhenRunning( # Apply all background updates on the database.
lambda: run_as_background_process("background_updates", run_background_updates) defer.ensureDeferred(
) run_as_background_process("background_updates", run_background_updates)
)
reactor.callWhenRunning(run)
reactor.run() reactor.run()

View File

@ -27,13 +27,16 @@ from six import string_types
import yaml import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
import synapse
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import (
from synapse.storage._base import LoggingTransaction LoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.data_stores.main.deviceinbox import ( from synapse.storage.data_stores.main.deviceinbox import (
DeviceInboxBackgroundUpdateStore, DeviceInboxBackgroundUpdateStore,
@ -61,6 +64,7 @@ from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock from synapse.util import Clock
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse_port_db") logger = logging.getLogger("synapse_port_db")
@ -125,6 +129,13 @@ APPEND_ONLY_TABLES = [
] ]
# Error set by the run function (through this global). Used at the top level of the
# script to report the error and to choose the return code.
end_error = None
# The exc_info (as given by sys.exc_info()) for the error, if any. If end_error is set
# but end_error_exec_info is not, the script shows only the error message, without the
# stacktrace. If end_error_exec_info is set but end_error is not, the script shows
# nothing beyond what was printed in the run function. If both are set, the script
# prints both the stacktrace and the error.
end_error_exec_info = None end_error_exec_info = None
@ -177,6 +188,7 @@ class MockHomeserver:
self.clock = Clock(reactor) self.clock = Clock(reactor)
self.config = config self.config = config
self.hostname = config.server_name self.hostname = config.server_name
self.version_string = "Synapse/"+get_version_string(synapse)
def get_clock(self): def get_clock(self):
return self.clock return self.clock
@ -189,11 +201,10 @@ class Porter(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
@defer.inlineCallbacks async def setup_table(self, table):
def setup_table(self, table):
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting. # It's safe to just carry on inserting.
row = yield self.postgres_store.db.simple_select_one( row = await self.postgres_store.db.simple_select_one(
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"), retcols=("forward_rowid", "backward_rowid"),
@ -207,10 +218,10 @@ class Porter(object):
forward_chunk, forward_chunk,
already_ported, already_ported,
total_to_port, total_to_port,
) = yield self._setup_sent_transactions() ) = await self._setup_sent_transactions()
backward_chunk = 0 backward_chunk = 0
else: else:
yield self.postgres_store.db.simple_insert( await self.postgres_store.db.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={
"table_name": table, "table_name": table,
@ -227,7 +238,7 @@ class Porter(object):
backward_chunk = row["backward_rowid"] backward_chunk = row["backward_rowid"]
if total_to_port is None: if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = await self._get_total_count_to_port(
table, forward_chunk, backward_chunk table, forward_chunk, backward_chunk
) )
else: else:
@ -238,9 +249,9 @@ class Porter(object):
) )
txn.execute("TRUNCATE %s CASCADE" % (table,)) txn.execute("TRUNCATE %s CASCADE" % (table,))
yield self.postgres_store.execute(delete_all) await self.postgres_store.execute(delete_all)
yield self.postgres_store.db.simple_insert( await self.postgres_store.db.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
) )
@ -248,16 +259,13 @@ class Porter(object):
forward_chunk = 1 forward_chunk = 1
backward_chunk = 0 backward_chunk = 0
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = await self._get_total_count_to_port(
table, forward_chunk, backward_chunk table, forward_chunk, backward_chunk
) )
defer.returnValue( return table, already_ported, total_to_port, forward_chunk, backward_chunk
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
)
@defer.inlineCallbacks async def handle_table(
def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk self, table, postgres_size, table_size, forward_chunk, backward_chunk
): ):
logger.info( logger.info(
@ -275,7 +283,7 @@ class Porter(object):
self.progress.add_table(table, postgres_size, table_size) self.progress.add_table(table, postgres_size, table_size)
if table == "event_search": if table == "event_search":
yield self.handle_search_table( await self.handle_search_table(
postgres_size, table_size, forward_chunk, backward_chunk postgres_size, table_size, forward_chunk, backward_chunk
) )
return return
@ -294,7 +302,7 @@ class Porter(object):
if table == "user_directory_stream_pos": if table == "user_directory_stream_pos":
# We need to make sure there is a single row, `(X, null), as that is # We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there. # what synapse expects to be there.
yield self.postgres_store.db.simple_insert( await self.postgres_store.db.simple_insert(
table=table, values={"stream_id": None} table=table, values={"stream_id": None}
) )
self.progress.update(table, table_size) # Mark table as done self.progress.update(table, table_size) # Mark table as done
@ -335,7 +343,7 @@ class Porter(object):
return headers, forward_rows, backward_rows return headers, forward_rows, backward_rows
headers, frows, brows = yield self.sqlite_store.db.runInteraction( headers, frows, brows = await self.sqlite_store.db.runInteraction(
"select", r "select", r
) )
@ -361,7 +369,7 @@ class Porter(object):
}, },
) )
yield self.postgres_store.execute(insert) await self.postgres_store.execute(insert)
postgres_size += len(rows) postgres_size += len(rows)
@ -369,8 +377,7 @@ class Porter(object):
else: else:
return return
@defer.inlineCallbacks async def handle_search_table(
def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk self, postgres_size, table_size, forward_chunk, backward_chunk
): ):
select = ( select = (
@ -390,7 +397,7 @@ class Porter(object):
return headers, rows return headers, rows
headers, rows = yield self.sqlite_store.db.runInteraction("select", r) headers, rows = await self.sqlite_store.db.runInteraction("select", r)
if rows: if rows:
forward_chunk = rows[-1][0] + 1 forward_chunk = rows[-1][0] + 1
@ -438,7 +445,7 @@ class Porter(object):
}, },
) )
yield self.postgres_store.execute(insert) await self.postgres_store.execute(insert)
postgres_size += len(rows) postgres_size += len(rows)
@ -476,11 +483,10 @@ class Porter(object):
return store return store
@defer.inlineCallbacks async def run_background_updates_on_postgres(self):
def run_background_updates_on_postgres(self):
# Manually apply all background updates on the PostgreSQL database. # Manually apply all background updates on the PostgreSQL database.
postgres_ready = ( postgres_ready = (
yield self.postgres_store.db.updates.has_completed_background_updates() await self.postgres_store.db.updates.has_completed_background_updates()
) )
if not postgres_ready: if not postgres_ready:
@ -489,13 +495,20 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL") self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready: while not postgres_ready:
yield self.postgres_store.db.updates.do_next_background_update(100) await self.postgres_store.db.updates.do_next_background_update(100)
postgres_ready = yield ( postgres_ready = await (
self.postgres_store.db.updates.has_completed_background_updates() self.postgres_store.db.updates.has_completed_background_updates()
) )
@defer.inlineCallbacks async def run(self):
def run(self): """Ports the SQLite database to a PostgreSQL database.
When a fatal error is met, its message is assigned to the global "end_error"
variable. When this error comes with a stacktrace, its exec_info is assigned to
the global "end_error_exec_info" variable.
"""
global end_error
try: try:
# we allow people to port away from outdated versions of sqlite. # we allow people to port away from outdated versions of sqlite.
self.sqlite_store = self.build_db_store( self.sqlite_store = self.build_db_store(
@ -505,21 +518,21 @@ class Porter(object):
# Check if all background updates are done, abort if not. # Check if all background updates are done, abort if not.
updates_complete = ( updates_complete = (
yield self.sqlite_store.db.updates.has_completed_background_updates() await self.sqlite_store.db.updates.has_completed_background_updates()
) )
if not updates_complete: if not updates_complete:
sys.stderr.write( end_error = (
"Pending background updates exist in the SQLite3 database." "Pending background updates exist in the SQLite3 database."
" Please start Synapse again and wait until every update has finished" " Please start Synapse again and wait until every update has finished"
" before running this script.\n" " before running this script.\n"
) )
defer.returnValue(None) return
self.postgres_store = self.build_db_store( self.postgres_store = self.build_db_store(
self.hs_config.get_single_database() self.hs_config.get_single_database()
) )
yield self.run_background_updates_on_postgres() await self.run_background_updates_on_postgres()
self.progress.set_state("Creating port tables") self.progress.set_state("Creating port tables")
@ -547,22 +560,22 @@ class Porter(object):
) )
try: try:
yield self.postgres_store.db.runInteraction("alter_table", alter_table) await self.postgres_store.db.runInteraction("alter_table", alter_table)
except Exception: except Exception:
# On Error Resume Next # On Error Resume Next
pass pass
yield self.postgres_store.db.runInteraction( await self.postgres_store.db.runInteraction(
"create_port_table", create_port_table "create_port_table", create_port_table
) )
# Step 2. Get tables. # Step 2. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store.db.simple_select_onecol( sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
table="sqlite_master", keyvalues={"type": "table"}, retcol="name" table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
) )
postgres_tables = yield self.postgres_store.db.simple_select_onecol( postgres_tables = await self.postgres_store.db.simple_select_onecol(
table="information_schema.tables", table="information_schema.tables",
keyvalues={}, keyvalues={},
retcol="distinct table_name", retcol="distinct table_name",
@ -573,28 +586,34 @@ class Porter(object):
# Step 3. Figure out what still needs copying # Step 3. Figure out what still needs copying
self.progress.set_state("Checking on port progress") self.progress.set_state("Checking on port progress")
setup_res = yield defer.gatherResults( setup_res = await make_deferred_yieldable(
[ defer.gatherResults(
self.setup_table(table) [
for table in tables run_in_background(self.setup_table, table)
if table not in ["schema_version", "applied_schema_deltas"] for table in tables
and not table.startswith("sqlite_") if table not in ["schema_version", "applied_schema_deltas"]
], and not table.startswith("sqlite_")
consumeErrors=True, ],
consumeErrors=True,
)
) )
# Step 4. Do the copying. # Step 4. Do the copying.
self.progress.set_state("Copying to postgres") self.progress.set_state("Copying to postgres")
yield defer.gatherResults( await make_deferred_yieldable(
[self.handle_table(*res) for res in setup_res], consumeErrors=True defer.gatherResults(
[run_in_background(self.handle_table, *res) for res in setup_res],
consumeErrors=True,
)
) )
# Step 5. Do final post-processing # Step 5. Do final post-processing
yield self._setup_state_group_id_seq() await self._setup_state_group_id_seq()
self.progress.done() self.progress.done()
except Exception: except Exception as e:
global end_error_exec_info global end_error_exec_info
end_error = e
end_error_exec_info = sys.exc_info() end_error_exec_info = sys.exc_info()
logger.exception("") logger.exception("")
finally: finally:
@ -634,8 +653,7 @@ class Porter(object):
return outrows return outrows
@defer.inlineCallbacks async def _setup_sent_transactions(self):
def _setup_sent_transactions(self):
# Only save things from the last day # Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000 yesterday = int(time.time() * 1000) - 86400000
@ -656,7 +674,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday] return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.db.runInteraction("select", r) headers, rows = await self.sqlite_store.db.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows) rows = self._convert_rows("sent_transactions", headers, rows)
@ -669,7 +687,7 @@ class Porter(object):
txn, "sent_transactions", headers[1:], rows txn, "sent_transactions", headers[1:], rows
) )
yield self.postgres_store.execute(insert) await self.postgres_store.execute(insert)
else: else:
max_inserted_rowid = 0 max_inserted_rowid = 0
@ -686,10 +704,10 @@ class Porter(object):
else: else:
return 1 return 1
next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = await self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk) next_chunk = max(max_inserted_rowid + 1, next_chunk)
yield self.postgres_store.db.simple_insert( await self.postgres_store.db.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={
"table_name": "sent_transactions", "table_name": "sent_transactions",
@ -705,46 +723,49 @@ class Porter(object):
(size,) = txn.fetchone() (size,) = txn.fetchone()
return int(size) return int(size)
remaining_count = yield self.sqlite_store.execute(get_sent_table_size) remaining_count = await self.sqlite_store.execute(get_sent_table_size)
total_count = remaining_count + inserted_rows total_count = remaining_count + inserted_rows
defer.returnValue((next_chunk, inserted_rows, total_count)) return next_chunk, inserted_rows, total_count
@defer.inlineCallbacks async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): frows = await self.sqlite_store.execute_sql(
frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
) )
brows = yield self.sqlite_store.execute_sql( brows = await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
) )
defer.returnValue(frows[0][0] + brows[0][0]) return frows[0][0] + brows[0][0]
@defer.inlineCallbacks async def _get_already_ported_count(self, table):
def _get_already_ported_count(self, table): rows = await self.postgres_store.execute_sql(
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,) "SELECT count(*) FROM %s" % (table,)
) )
defer.returnValue(rows[0][0]) return rows[0][0]
@defer.inlineCallbacks async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): remaining, done = await make_deferred_yieldable(
remaining, done = yield defer.gatherResults( defer.gatherResults(
[ [
self._get_remaining_count_to_port(table, forward_chunk, backward_chunk), run_in_background(
self._get_already_ported_count(table), self._get_remaining_count_to_port,
], table,
consumeErrors=True, forward_chunk,
backward_chunk,
),
run_in_background(self._get_already_ported_count, table),
],
)
) )
remaining = int(remaining) if remaining else 0 remaining = int(remaining) if remaining else 0
done = int(done) if done else 0 done = int(done) if done else 0
defer.returnValue((done, remaining + done)) return done, remaining + done
def _setup_state_group_id_seq(self): def _setup_state_group_id_seq(self):
def r(txn): def r(txn):
@ -1010,7 +1031,12 @@ if __name__ == "__main__":
hs_config=config, hs_config=config,
) )
reactor.callWhenRunning(porter.run) @defer.inlineCallbacks
def run():
with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run())
reactor.callWhenRunning(run)
reactor.run() reactor.run()
@ -1019,7 +1045,11 @@ if __name__ == "__main__":
else: else:
start() start()
if end_error_exec_info: if end_error:
exc_type, exc_value, exc_traceback = end_error_exec_info if end_error_exec_info:
traceback.print_exception(exc_type, exc_value, exc_traceback) exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)
sys.stderr.write(end_error)
sys.exit(5) sys.exit(5)