Make scripts/ and scripts-dev/ pass pyflakes (and the rest of the codebase on py3) (#4068)

Amber Brown 2018-10-20 11:16:55 +11:00 committed by GitHub
parent 81d4f51524
commit e1728dfcbe
27 changed files with 511 additions and 518 deletions
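Most of the changes below are mechanical cleanups that pyflakes (and py3) require: imports are sorted and deduplicated, bare `except:` clauses are narrowed to `except Exception:`, Python 2 `print` statements become `print()` calls, and long call sites are reflowed. As a hedged illustration of how such a check might be run over the two directories from the commit title — assuming pyflakes' Python API (`pyflakes.api.checkRecursive` and `pyflakes.reporter.Reporter`), not anything added by this commit — a minimal sketch:

# Minimal sketch, not part of this commit: run a pyflakes pass over the two
# directories named in the commit title.
import sys

from pyflakes.api import checkRecursive
from pyflakes.reporter import Reporter

warnings = checkRecursive(["scripts", "scripts-dev"], Reporter(sys.stdout, sys.stderr))
sys.exit(1 if warnings else 0)  # fail the build if anything was flagged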

scripts/synapse_port_db

@@ -15,23 +15,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse
import curses
import logging
import sys
import time
import traceback
import yaml
from six import string_types
import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("synapse_port_db")
@@ -105,6 +105,7 @@ class Store(object):
*All* database interactions should go through this object.
"""
def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
@@ -135,7 +136,8 @@ class Store(object):
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine, [], []),
*args, **kwargs
*args,
**kwargs
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
@@ -158,22 +160,20 @@ class Store(object):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
", ".join("%s" for _ in headers)
", ".join("%s" for _ in headers),
)
try:
txn.executemany(sql, rows)
except:
logger.exception(
"Failed to insert: %s",
table,
)
except Exception:
logger.exception("Failed to insert: %s", table)
raise
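The `except:` → `except Exception:` change in `insert_many_txn` above (and again later in `run()`) is one of the lint fixes this commit is about: a bare `except` also catches `KeyboardInterrupt` and `SystemExit`, so a Ctrl-C during the bulk insert would be logged as a failed insert before propagating. A small standalone illustration of the difference — `do_insert` is a made-up stand-in for the `txn.executemany(...)` call:

# Illustration only; not code from the port script.
import logging

logger = logging.getLogger(__name__)

def insert_bare(do_insert):
    try:
        do_insert()
    except:              # also traps KeyboardInterrupt/SystemExit
        logger.exception("Failed to insert")   # Ctrl-C gets logged as a failure here
        raise

def insert_narrow(do_insert):
    try:
        do_insert()
    except Exception:    # KeyboardInterrupt/SystemExit pass straight through
        logger.exception("Failed to insert")
        raise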
@@ -206,7 +206,7 @@ class Porter(object):
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
},
)
forward_chunk = 1
@@ -221,10 +221,10 @@ class Porter(object):
table, forward_chunk, backward_chunk
)
else:
def delete_all(txn):
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
(table,)
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
)
txn.execute("TRUNCATE %s CASCADE" % (table,))
@@ -232,11 +232,7 @@ class Porter(object):
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
)
forward_chunk = 1
@@ -251,12 +247,16 @@ class Porter(object):
)
@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table, postgres_size, table_size,
backward_chunk+1, forward_chunk-1,
table,
postgres_size,
table_size,
backward_chunk + 1,
forward_chunk - 1,
)
if not table_size:
@@ -271,7 +271,9 @@ class Porter(object):
return
if table in (
"user_directory", "user_directory_search", "users_who_share_rooms",
"user_directory",
"user_directory_search",
"users_who_share_rooms",
"users_in_pubic_room",
):
# We don't port these tables, as they're a faff and we can regenerate
@@ -283,37 +285,35 @@ class Porter(object):
# We need to make sure there is a single row, `(X, null)`, as that is
# what synapse expects to be there.
yield self.postgres_store._simple_insert(
table=table,
values={"stream_id": None},
table=table, values={"stream_id": None}
)
self.progress.update(table, table_size) # Mark table as done
return
forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
)
backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
% (table,)
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
)
do_forward = [True]
do_backward = [True]
while True:
def r(txn):
forward_rows = []
backward_rows = []
if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,))
txn.execute(forward_select, (forward_chunk, self.batch_size))
forward_rows = txn.fetchall()
if not forward_rows:
do_forward[0] = False
if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,))
txn.execute(backward_select, (backward_chunk, self.batch_size))
backward_rows = txn.fetchall()
if not backward_rows:
do_backward[0] = False
@@ -325,9 +325,7 @@ class Porter(object):
return headers, forward_rows, backward_rows
headers, frows, brows = yield self.sqlite_store.runInteraction(
"select", r
)
headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
if frows or brows:
if frows:
@@ -339,9 +337,7 @@ class Porter(object):
rows = self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store._simple_update_one_txn(
txn,
@@ -362,8 +358,9 @@ class Porter(object):
return
@defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk
):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -373,8 +370,9 @@ class Porter(object):
)
while True:
def r(txn):
txn.execute(select, (forward_chunk, self.batch_size,))
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@@ -402,18 +400,21 @@ class Porter(object):
else:
rows_dict.append(d)
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
])
txn.executemany(
sql,
[
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
],
)
self.postgres_store._simple_update_one_txn(
txn,
@@ -437,7 +438,8 @@ class Porter(object):
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
k: v for k, v in db_config.get("args", {}).items()
k: v
for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
@@ -450,13 +452,11 @@ class Porter(object):
def run(self):
try:
sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"],
**self.sqlite_config["args"]
self.sqlite_config["name"], **self.sqlite_config["args"]
)
postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"],
**self.postgres_config["args"]
self.postgres_config["name"], **self.postgres_config["args"]
)
sqlite_engine = create_engine(sqlite_config)
@@ -465,9 +465,7 @@ class Porter(object):
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
yield self.postgres_store.execute(
postgres_engine.check_database
)
yield self.postgres_store.execute(postgres_engine.check_database)
# Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3")
@@ -477,6 +475,7 @@ class Porter(object):
self.setup_db(postgres_config, postgres_engine)
self.progress.set_state("Creating port tables")
def create_port_table(txn):
txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
@@ -501,9 +500,7 @@ class Porter(object):
)
try:
yield self.postgres_store.runInteraction(
"alter_table", alter_table
)
yield self.postgres_store.runInteraction("alter_table", alter_table)
except Exception as e:
pass
@@ -514,11 +511,7 @@ class Porter(object):
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
)
postgres_tables = yield self.postgres_store._simple_select_onecol(
@@ -545,18 +538,14 @@ class Porter(object):
# Step 4. Do the copying.
self.progress.set_state("Copying to postgres")
yield defer.gatherResults(
[
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
[self.handle_table(*res) for res in setup_res], consumeErrors=True
)
# Step 5. Do final post-processing
yield self._setup_state_group_id_seq()
self.progress.done()
except:
except Exception:
global end_error_exec_info
end_error_exec_info = sys.exc_info()
logger.exception("")
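The `defer.gatherResults([...], consumeErrors=True)` call in step 4 above ports all tables concurrently; with `consumeErrors=True`, a failing table errbacks the combined Deferred with a `defer.FirstError` instead of leaving each individual failure unhandled. A self-contained sketch of that Twisted behaviour — the succeeding and failing deferreds are invented for the example:

# Standalone Twisted example; unrelated to the real handle_table coroutines.
from twisted.internet import defer

@defer.inlineCallbacks
def demo():
    good = defer.succeed("table ported")
    bad = defer.fail(RuntimeError("boom"))
    try:
        yield defer.gatherResults([good, bad], consumeErrors=True)
    except defer.FirstError as e:
        # The original RuntimeError is available on the wrapped Failure.
        print("first failure:", e.subFailure.value)

demo()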
@@ -566,9 +555,7 @@ class Porter(object):
def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [
i for i, h in enumerate(headers) if h in bool_col_names
]
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
class BadValueException(Exception):
pass
@@ -577,18 +564,21 @@ class Porter(object):
if j in bool_cols:
return bool(col)
elif isinstance(col, string_types) and "\0" in col:
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
raise BadValueException();
logger.warn(
"DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
col,
)
raise BadValueException()
return col
outrows = []
for i, row in enumerate(rows):
try:
outrows.append(tuple(
conv(j, col)
for j, col in enumerate(row)
if j > 0
))
outrows.append(
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
)
except BadValueException:
pass
@@ -616,9 +606,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction(
"select", r,
)
headers, rows = yield self.sqlite_store.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows)
@@ -639,7 +627,7 @@ class Porter(object):
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
(yesterday,)
(yesterday,),
)
rows = txn.fetchall()
@@ -657,21 +645,17 @@ class Porter(object):
"table_name": "sent_transactions",
"forward_rowid": next_chunk,
"backward_rowid": 0,
}
},
)
def get_sent_table_size(txn):
txn.execute(
"SELECT count(*) FROM sent_transactions"
" WHERE ts >= ?",
(yesterday,)
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
size, = txn.fetchone()
return int(size)
remaining_count = yield self.sqlite_store.execute(
get_sent_table_size
)
remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
total_count = remaining_count + inserted_rows
@@ -680,13 +664,11 @@ class Porter(object):
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
forward_chunk,
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
)
brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
backward_chunk,
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
)
defer.returnValue(frows[0][0] + brows[0][0])
@@ -694,7 +676,7 @@ class Porter(object):
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,),
"SELECT count(*) FROM %s" % (table,)
)
defer.returnValue(rows[0][0])
@@ -717,22 +699,21 @@ class Porter(object):
def _setup_state_group_id_seq(self):
def r(txn):
txn.execute("SELECT MAX(id) FROM state_groups")
next_id = txn.fetchone()[0]+1
txn.execute(
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
(next_id,),
)
next_id = txn.fetchone()[0] + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
##############################################
###### The following is simply UI stuff ######
# The following is simply UI stuff
##############################################
class Progress(object):
"""Used to report progress of the port
"""
def __init__(self):
self.tables = {}
@@ -758,6 +739,7 @@ class Progress(object):
class CursesProgress(Progress):
"""Reports progress to a curses window
"""
def __init__(self, stdscr):
self.stdscr = stdscr
@@ -801,7 +783,7 @@ class CursesProgress(Progress):
duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,)
duration_str = '%02dm %02ds' % (minutes, seconds)
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
@@ -814,16 +796,12 @@ class CursesProgress(Progress):
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = (
"Time spent: %s (est. remaining: %s)"
% (duration_str, est_remaining_str,)
status = "Time spent: %s (est. remaining: %s)" % (
duration_str,
est_remaining_str,
)
self.stdscr.addstr(
0, 0,
status,
curses.A_BOLD,
)
self.stdscr.addstr(0, 0, status, curses.A_BOLD)
max_len = max([len(t) for t in self.tables.keys()])
@@ -831,9 +809,7 @@ class CursesProgress(Progress):
middle_space = 1
items = self.tables.items()
items.sort(
key=lambda i: (i[1]["perc"], i[0]),
)
items.sort(key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@@ -844,9 +820,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
i + 2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
)
size = 20
@@ -857,15 +831,13 @@ class CursesProgress(Progress):
)
self.stdscr.addstr(
i + 2, left_margin + max_len + middle_space,
i + 2,
left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
rows - 1, 0,
"Press any key to exit...",
)
self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")
self.stdscr.refresh()
self.last_update = time.time()
@@ -877,29 +849,25 @@ class CursesProgress(Progress):
def set_state(self, state):
self.stdscr.clear()
self.stdscr.addstr(
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh()
class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""
def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
print "%s: %d%% (%d/%d)" % (
table, data["perc"],
data["num_done"], data["total"],
print(
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
)
def set_state(self, state):
print state + "..."
print(state + "...")
##############################################
@@ -909,34 +877,38 @@ class TerminalProgress(Progress):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument(
"--sqlite-database", required=True,
"--sqlite-database",
required=True,
help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server"
" currently used by a running synapse server",
)
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True,
help="The database config file for the PostgreSQL database"
"--postgres-config",
type=argparse.FileType('r'),
required=True,
help="The database config file for the PostgreSQL database",
)
parser.add_argument(
"--curses", action='store_true',
help="display a curses based progress UI"
"--curses", action='store_true', help="display a curses based progress UI"
)
parser.add_argument(
"--batch-size", type=int, default=1000,
"--batch-size",
type=int,
default=1000,
help="The number of rows to select from the SQLite table each"
" iteration [default=1000]",
" iteration [default=1000]",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
if args.curses:
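The diff is truncated here, mid-hunk. Purely for orientation — this is not the actual tail of `synapse_port_db` — a hedged sketch of how a `--curses` switch like the one above is commonly wired up with `curses.wrapper`, reusing the two Progress classes shown earlier:

# Sketch only, assuming the CursesProgress/TerminalProgress classes from the
# diff above and the parsed `args` from argparse.
import curses

def start(stdscr=None):
    # curses.wrapper hands us an initialised screen; without one, fall back
    # to plain terminal output.
    progress = CursesProgress(stdscr) if stdscr else TerminalProgress()
    progress.set_state("Starting")

if args.curses:
    curses.wrapper(start)   # handles curses setup/teardown and restores the terminal
else:
    start()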