mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 08:56:04 -04:00
Make scripts/ and scripts-dev/ pass pyflakes (and the rest of the codebase on py3) (#4068)
This commit is contained in:
parent
81d4f51524
commit
e1728dfcbe
27 changed files with 511 additions and 518 deletions
|
@ -1,12 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
import bcrypt
|
||||
import getpass
|
||||
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds=12
|
||||
|
@ -52,4 +50,3 @@ if __name__ == "__main__":
|
|||
password = prompt_for_pass()
|
||||
|
||||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||
|
||||
|
|
|
@ -36,12 +36,9 @@ from __future__ import print_function
|
|||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||
|
||||
|
@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
|
|||
if not os.path.exists(original_file):
|
||||
logger.warn(
|
||||
"Original for %s/%s (%s) does not exist",
|
||||
origin_server, file_id, original_file,
|
||||
origin_server,
|
||||
file_id,
|
||||
original_file,
|
||||
)
|
||||
else:
|
||||
mkdir_and_move(
|
||||
original_file,
|
||||
dest_paths.remote_media_filepath(origin_server, file_id),
|
||||
original_file, dest_paths.remote_media_filepath(origin_server, file_id)
|
||||
)
|
||||
|
||||
# now look for thumbnails
|
||||
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
|
||||
origin_server, file_id,
|
||||
)
|
||||
original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id)
|
||||
if not os.path.exists(original_thumb_dir):
|
||||
return
|
||||
|
||||
mkdir_and_move(
|
||||
original_thumb_dir,
|
||||
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
|
||||
dest_paths.remote_media_thumbnail_dir(origin_server, file_id),
|
||||
)
|
||||
|
||||
|
||||
|
@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class = argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", action='store_true', help='enable debug logging')
|
||||
parser.add_argument(
|
||||
"src_repo",
|
||||
help="Path to source content repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"dest_repo",
|
||||
help="Path to source content repo",
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("-v", action='store_true', help='enable debug logging')
|
||||
parser.add_argument("src_repo", help="Path to source content repo")
|
||||
parser.add_argument("dest_repo", help="Path to source content repo")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
}
|
||||
logging.basicConfig(**logging_config)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
|
@ -22,19 +23,23 @@ import hmac
|
|||
import json
|
||||
import sys
|
||||
import urllib2
|
||||
|
||||
from six import input
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/r0/admin/register" % (server_location,),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
headers={'Content-Type': 'application/json'},
|
||||
)
|
||||
|
||||
try:
|
||||
if sys.version_info[:3] >= (2, 7, 9):
|
||||
# As of version 2.7.9, urllib2 now checks SSL certs
|
||||
import ssl
|
||||
|
||||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
else:
|
||||
f = urllib2.urlopen(req)
|
||||
|
@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F
|
|||
f.close()
|
||||
nonce = json.loads(body)["nonce"]
|
||||
except urllib2.HTTPError as e:
|
||||
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||
print("ERROR! Received %d %s" % (e.code, e.reason))
|
||||
if 400 <= e.code < 500:
|
||||
if e.info().type == "application/json":
|
||||
resp = json.load(e)
|
||||
if "error" in resp:
|
||||
print resp["error"]
|
||||
print(resp["error"])
|
||||
sys.exit(1)
|
||||
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
digestmod=hashlib.sha1,
|
||||
)
|
||||
mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
|
||||
|
||||
mac.update(nonce)
|
||||
mac.update("\x00")
|
||||
|
@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F
|
|||
|
||||
server_location = server_location.rstrip("/")
|
||||
|
||||
print "Sending registration request..."
|
||||
print("Sending registration request...")
|
||||
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/r0/admin/register" % (server_location,),
|
||||
data=json.dumps(data),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
headers={'Content-Type': 'application/json'},
|
||||
)
|
||||
try:
|
||||
if sys.version_info[:3] >= (2, 7, 9):
|
||||
# As of version 2.7.9, urllib2 now checks SSL certs
|
||||
import ssl
|
||||
|
||||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
else:
|
||||
f = urllib2.urlopen(req)
|
||||
f.read()
|
||||
f.close()
|
||||
print "Success."
|
||||
print("Success.")
|
||||
except urllib2.HTTPError as e:
|
||||
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||
print("ERROR! Received %d %s" % (e.code, e.reason))
|
||||
if 400 <= e.code < 500:
|
||||
if e.info().type == "application/json":
|
||||
resp = json.load(e)
|
||||
if "error" in resp:
|
||||
print resp["error"]
|
||||
print(resp["error"])
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin):
|
|||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
except:
|
||||
except Exception:
|
||||
default_user = None
|
||||
|
||||
if default_user:
|
||||
user = raw_input("New user localpart [%s]: " % (default_user,))
|
||||
user = input("New user localpart [%s]: " % (default_user,))
|
||||
if not user:
|
||||
user = default_user
|
||||
else:
|
||||
user = raw_input("New user localpart: ")
|
||||
user = input("New user localpart: ")
|
||||
|
||||
if not user:
|
||||
print "Invalid user name"
|
||||
print("Invalid user name")
|
||||
sys.exit(1)
|
||||
|
||||
if not password:
|
||||
password = getpass.getpass("Password: ")
|
||||
|
||||
if not password:
|
||||
print "Password cannot be blank."
|
||||
print("Password cannot be blank.")
|
||||
sys.exit(1)
|
||||
|
||||
confirm_password = getpass.getpass("Confirm password: ")
|
||||
|
||||
if password != confirm_password:
|
||||
print "Passwords do not match"
|
||||
print("Passwords do not match")
|
||||
sys.exit(1)
|
||||
|
||||
if admin is None:
|
||||
admin = raw_input("Make admin [no]: ")
|
||||
admin = input("Make admin [no]: ")
|
||||
if admin in ("y", "yes", "true"):
|
||||
admin = True
|
||||
else:
|
||||
|
@ -146,42 +149,51 @@ def register_new_user(user, password, server_location, shared_secret, admin):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Used to register new users with a given home server when"
|
||||
" registration has been disabled. The home server must be"
|
||||
" configured with the 'registration_shared_secret' option"
|
||||
" set.",
|
||||
" registration has been disabled. The home server must be"
|
||||
" configured with the 'registration_shared_secret' option"
|
||||
" set."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--user",
|
||||
"-u",
|
||||
"--user",
|
||||
default=None,
|
||||
help="Local part of the new user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--password",
|
||||
"-p",
|
||||
"--password",
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
admin_group = parser.add_mutually_exclusive_group()
|
||||
admin_group.add_argument(
|
||||
"-a", "--admin",
|
||||
"-a",
|
||||
"--admin",
|
||||
action="store_true",
|
||||
help="Register new user as an admin. Will prompt if --no-admin is not set either.",
|
||||
help=(
|
||||
"Register new user as an admin. "
|
||||
"Will prompt if --no-admin is not set either."
|
||||
),
|
||||
)
|
||||
admin_group.add_argument(
|
||||
"--no-admin",
|
||||
action="store_true",
|
||||
help="Register new user as a regular user. Will prompt if --admin is not set either.",
|
||||
help=(
|
||||
"Register new user as a regular user. "
|
||||
"Will prompt if --admin is not set either."
|
||||
),
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-c", "--config",
|
||||
"-c",
|
||||
"--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in shared secret.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"-k", "--shared-secret",
|
||||
help="Shared secret as defined in server config file.",
|
||||
"-k", "--shared-secret", help="Shared secret as defined in server config file."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -189,7 +201,7 @@ if __name__ == "__main__":
|
|||
default="https://localhost:8448",
|
||||
nargs='?',
|
||||
help="URL to use to talk to the home server. Defaults to "
|
||||
" 'https://localhost:8448'.",
|
||||
" 'https://localhost:8448'.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -198,7 +210,7 @@ if __name__ == "__main__":
|
|||
config = yaml.safe_load(args.config)
|
||||
secret = config.get("registration_shared_secret", None)
|
||||
if not secret:
|
||||
print "No 'registration_shared_secret' defined in config."
|
||||
print("No 'registration_shared_secret' defined in config.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
secret = args.shared_secret
|
||||
|
|
|
@ -15,23 +15,23 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.enterprise import adbapi
|
||||
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from six import string_types
|
||||
|
||||
import yaml
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
logger = logging.getLogger("synapse_port_db")
|
||||
|
||||
|
@ -105,6 +105,7 @@ class Store(object):
|
|||
|
||||
*All* database interactions should go through this object.
|
||||
"""
|
||||
|
||||
def __init__(self, db_pool, engine):
|
||||
self.db_pool = db_pool
|
||||
self.database_engine = engine
|
||||
|
@ -135,7 +136,8 @@ class Store(object):
|
|||
txn = conn.cursor()
|
||||
return func(
|
||||
LoggingTransaction(txn, desc, self.database_engine, [], []),
|
||||
*args, **kwargs
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
except self.database_engine.module.DatabaseError as e:
|
||||
if self.database_engine.is_deadlock(e):
|
||||
|
@ -158,22 +160,20 @@ class Store(object):
|
|||
def r(txn):
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
return self.runInteraction("execute_sql", r)
|
||||
|
||||
def insert_many_txn(self, txn, table, headers, rows):
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
", ".join(k for k in headers),
|
||||
", ".join("%s" for _ in headers)
|
||||
", ".join("%s" for _ in headers),
|
||||
)
|
||||
|
||||
try:
|
||||
txn.executemany(sql, rows)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to insert: %s",
|
||||
table,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to insert: %s", table)
|
||||
raise
|
||||
|
||||
|
||||
|
@ -206,7 +206,7 @@ class Porter(object):
|
|||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
forward_chunk = 1
|
||||
|
@ -221,10 +221,10 @@ class Porter(object):
|
|||
table, forward_chunk, backward_chunk
|
||||
)
|
||||
else:
|
||||
|
||||
def delete_all(txn):
|
||||
txn.execute(
|
||||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
|
||||
(table,)
|
||||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
|
||||
)
|
||||
txn.execute("TRUNCATE %s CASCADE" % (table,))
|
||||
|
||||
|
@ -232,11 +232,7 @@ class Porter(object):
|
|||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={
|
||||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
|
||||
)
|
||||
|
||||
forward_chunk = 1
|
||||
|
@ -251,12 +247,16 @@ class Porter(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_table(self, table, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
def handle_table(
|
||||
self, table, postgres_size, table_size, forward_chunk, backward_chunk
|
||||
):
|
||||
logger.info(
|
||||
"Table %s: %i/%i (rows %i-%i) already ported",
|
||||
table, postgres_size, table_size,
|
||||
backward_chunk+1, forward_chunk-1,
|
||||
table,
|
||||
postgres_size,
|
||||
table_size,
|
||||
backward_chunk + 1,
|
||||
forward_chunk - 1,
|
||||
)
|
||||
|
||||
if not table_size:
|
||||
|
@ -271,7 +271,9 @@ class Porter(object):
|
|||
return
|
||||
|
||||
if table in (
|
||||
"user_directory", "user_directory_search", "users_who_share_rooms",
|
||||
"user_directory",
|
||||
"user_directory_search",
|
||||
"users_who_share_rooms",
|
||||
"users_in_pubic_room",
|
||||
):
|
||||
# We don't port these tables, as they're a faff and we can regenreate
|
||||
|
@ -283,37 +285,35 @@ class Porter(object):
|
|||
# We need to make sure there is a single row, `(X, null), as that is
|
||||
# what synapse expects to be there.
|
||||
yield self.postgres_store._simple_insert(
|
||||
table=table,
|
||||
values={"stream_id": None},
|
||||
table=table, values={"stream_id": None}
|
||||
)
|
||||
self.progress.update(table, table_size) # Mark table as done
|
||||
return
|
||||
|
||||
forward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
|
||||
)
|
||||
|
||||
backward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
|
||||
)
|
||||
|
||||
do_forward = [True]
|
||||
do_backward = [True]
|
||||
|
||||
while True:
|
||||
|
||||
def r(txn):
|
||||
forward_rows = []
|
||||
backward_rows = []
|
||||
if do_forward[0]:
|
||||
txn.execute(forward_select, (forward_chunk, self.batch_size,))
|
||||
txn.execute(forward_select, (forward_chunk, self.batch_size))
|
||||
forward_rows = txn.fetchall()
|
||||
if not forward_rows:
|
||||
do_forward[0] = False
|
||||
|
||||
if do_backward[0]:
|
||||
txn.execute(backward_select, (backward_chunk, self.batch_size,))
|
||||
txn.execute(backward_select, (backward_chunk, self.batch_size))
|
||||
backward_rows = txn.fetchall()
|
||||
if not backward_rows:
|
||||
do_backward[0] = False
|
||||
|
@ -325,9 +325,7 @@ class Porter(object):
|
|||
|
||||
return headers, forward_rows, backward_rows
|
||||
|
||||
headers, frows, brows = yield self.sqlite_store.runInteraction(
|
||||
"select", r
|
||||
)
|
||||
headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
if frows or brows:
|
||||
if frows:
|
||||
|
@ -339,9 +337,7 @@ class Porter(object):
|
|||
rows = self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
txn, table, headers[1:], rows
|
||||
)
|
||||
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -362,8 +358,9 @@ class Porter(object):
|
|||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_search_table(self, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
def handle_search_table(
|
||||
self, postgres_size, table_size, forward_chunk, backward_chunk
|
||||
):
|
||||
select = (
|
||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||
" FROM event_search as es"
|
||||
|
@ -373,8 +370,9 @@ class Porter(object):
|
|||
)
|
||||
|
||||
while True:
|
||||
|
||||
def r(txn):
|
||||
txn.execute(select, (forward_chunk, self.batch_size,))
|
||||
txn.execute(select, (forward_chunk, self.batch_size))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
|
@ -402,18 +400,21 @@ class Porter(object):
|
|||
else:
|
||||
rows_dict.append(d)
|
||||
|
||||
txn.executemany(sql, [
|
||||
(
|
||||
row["event_id"],
|
||||
row["room_id"],
|
||||
row["key"],
|
||||
row["sender"],
|
||||
row["value"],
|
||||
row["origin_server_ts"],
|
||||
row["stream_ordering"],
|
||||
)
|
||||
for row in rows_dict
|
||||
])
|
||||
txn.executemany(
|
||||
sql,
|
||||
[
|
||||
(
|
||||
row["event_id"],
|
||||
row["room_id"],
|
||||
row["key"],
|
||||
row["sender"],
|
||||
row["value"],
|
||||
row["origin_server_ts"],
|
||||
row["stream_ordering"],
|
||||
)
|
||||
for row in rows_dict
|
||||
],
|
||||
)
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -437,7 +438,8 @@ class Porter(object):
|
|||
def setup_db(self, db_config, database_engine):
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
k: v for k, v in db_config.get("args", {}).items()
|
||||
k: v
|
||||
for k, v in db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
)
|
||||
|
@ -450,13 +452,11 @@ class Porter(object):
|
|||
def run(self):
|
||||
try:
|
||||
sqlite_db_pool = adbapi.ConnectionPool(
|
||||
self.sqlite_config["name"],
|
||||
**self.sqlite_config["args"]
|
||||
self.sqlite_config["name"], **self.sqlite_config["args"]
|
||||
)
|
||||
|
||||
postgres_db_pool = adbapi.ConnectionPool(
|
||||
self.postgres_config["name"],
|
||||
**self.postgres_config["args"]
|
||||
self.postgres_config["name"], **self.postgres_config["args"]
|
||||
)
|
||||
|
||||
sqlite_engine = create_engine(sqlite_config)
|
||||
|
@ -465,9 +465,7 @@ class Porter(object):
|
|||
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
|
||||
self.postgres_store = Store(postgres_db_pool, postgres_engine)
|
||||
|
||||
yield self.postgres_store.execute(
|
||||
postgres_engine.check_database
|
||||
)
|
||||
yield self.postgres_store.execute(postgres_engine.check_database)
|
||||
|
||||
# Step 1. Set up databases.
|
||||
self.progress.set_state("Preparing SQLite3")
|
||||
|
@ -477,6 +475,7 @@ class Porter(object):
|
|||
self.setup_db(postgres_config, postgres_engine)
|
||||
|
||||
self.progress.set_state("Creating port tables")
|
||||
|
||||
def create_port_table(txn):
|
||||
txn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
|
||||
|
@ -501,9 +500,7 @@ class Porter(object):
|
|||
)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"alter_table", alter_table
|
||||
)
|
||||
yield self.postgres_store.runInteraction("alter_table", alter_table)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
@ -514,11 +511,7 @@ class Porter(object):
|
|||
# Step 2. Get tables.
|
||||
self.progress.set_state("Fetching tables")
|
||||
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
|
||||
table="sqlite_master",
|
||||
keyvalues={
|
||||
"type": "table",
|
||||
},
|
||||
retcol="name",
|
||||
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
|
||||
)
|
||||
|
||||
postgres_tables = yield self.postgres_store._simple_select_onecol(
|
||||
|
@ -545,18 +538,14 @@ class Porter(object):
|
|||
# Step 4. Do the copying.
|
||||
self.progress.set_state("Copying to postgres")
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.handle_table(*res)
|
||||
for res in setup_res
|
||||
],
|
||||
consumeErrors=True,
|
||||
[self.handle_table(*res) for res in setup_res], consumeErrors=True
|
||||
)
|
||||
|
||||
# Step 5. Do final post-processing
|
||||
yield self._setup_state_group_id_seq()
|
||||
|
||||
self.progress.done()
|
||||
except:
|
||||
except Exception:
|
||||
global end_error_exec_info
|
||||
end_error_exec_info = sys.exc_info()
|
||||
logger.exception("")
|
||||
|
@ -566,9 +555,7 @@ class Porter(object):
|
|||
def _convert_rows(self, table, headers, rows):
|
||||
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
|
||||
|
||||
bool_cols = [
|
||||
i for i, h in enumerate(headers) if h in bool_col_names
|
||||
]
|
||||
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
|
||||
|
||||
class BadValueException(Exception):
|
||||
pass
|
||||
|
@ -577,18 +564,21 @@ class Porter(object):
|
|||
if j in bool_cols:
|
||||
return bool(col)
|
||||
elif isinstance(col, string_types) and "\0" in col:
|
||||
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
|
||||
raise BadValueException();
|
||||
logger.warn(
|
||||
"DROPPING ROW: NUL value in table %s col %s: %r",
|
||||
table,
|
||||
headers[j],
|
||||
col,
|
||||
)
|
||||
raise BadValueException()
|
||||
return col
|
||||
|
||||
outrows = []
|
||||
for i, row in enumerate(rows):
|
||||
try:
|
||||
outrows.append(tuple(
|
||||
conv(j, col)
|
||||
for j, col in enumerate(row)
|
||||
if j > 0
|
||||
))
|
||||
outrows.append(
|
||||
tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
|
||||
)
|
||||
except BadValueException:
|
||||
pass
|
||||
|
||||
|
@ -616,9 +606,7 @@ class Porter(object):
|
|||
|
||||
return headers, [r for r in rows if r[ts_ind] < yesterday]
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction(
|
||||
"select", r,
|
||||
)
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
rows = self._convert_rows("sent_transactions", headers, rows)
|
||||
|
||||
|
@ -639,7 +627,7 @@ class Porter(object):
|
|||
txn.execute(
|
||||
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
|
||||
" ORDER BY rowid ASC LIMIT 1",
|
||||
(yesterday,)
|
||||
(yesterday,),
|
||||
)
|
||||
|
||||
rows = txn.fetchall()
|
||||
|
@ -657,21 +645,17 @@ class Porter(object):
|
|||
"table_name": "sent_transactions",
|
||||
"forward_rowid": next_chunk,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_sent_table_size(txn):
|
||||
txn.execute(
|
||||
"SELECT count(*) FROM sent_transactions"
|
||||
" WHERE ts >= ?",
|
||||
(yesterday,)
|
||||
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
|
||||
)
|
||||
size, = txn.fetchone()
|
||||
return int(size)
|
||||
|
||||
remaining_count = yield self.sqlite_store.execute(
|
||||
get_sent_table_size
|
||||
)
|
||||
remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
|
||||
|
||||
total_count = remaining_count + inserted_rows
|
||||
|
||||
|
@ -680,13 +664,11 @@ class Porter(object):
|
|||
@defer.inlineCallbacks
|
||||
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||
frows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
||||
forward_chunk,
|
||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
|
||||
)
|
||||
|
||||
brows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
|
||||
backward_chunk,
|
||||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
|
||||
)
|
||||
|
||||
defer.returnValue(frows[0][0] + brows[0][0])
|
||||
|
@ -694,7 +676,7 @@ class Porter(object):
|
|||
@defer.inlineCallbacks
|
||||
def _get_already_ported_count(self, table):
|
||||
rows = yield self.postgres_store.execute_sql(
|
||||
"SELECT count(*) FROM %s" % (table,),
|
||||
"SELECT count(*) FROM %s" % (table,)
|
||||
)
|
||||
|
||||
defer.returnValue(rows[0][0])
|
||||
|
@ -717,22 +699,21 @@ class Porter(object):
|
|||
def _setup_state_group_id_seq(self):
|
||||
def r(txn):
|
||||
txn.execute("SELECT MAX(id) FROM state_groups")
|
||||
next_id = txn.fetchone()[0]+1
|
||||
txn.execute(
|
||||
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
|
||||
(next_id,),
|
||||
)
|
||||
next_id = txn.fetchone()[0] + 1
|
||||
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
|
||||
|
||||
return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
|
||||
|
||||
|
||||
##############################################
|
||||
###### The following is simply UI stuff ######
|
||||
# The following is simply UI stuff
|
||||
##############################################
|
||||
|
||||
|
||||
class Progress(object):
|
||||
"""Used to report progress of the port
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tables = {}
|
||||
|
||||
|
@ -758,6 +739,7 @@ class Progress(object):
|
|||
class CursesProgress(Progress):
|
||||
"""Reports progress to a curses window
|
||||
"""
|
||||
|
||||
def __init__(self, stdscr):
|
||||
self.stdscr = stdscr
|
||||
|
||||
|
@ -801,7 +783,7 @@ class CursesProgress(Progress):
|
|||
duration = int(now) - int(self.start_time)
|
||||
|
||||
minutes, seconds = divmod(duration, 60)
|
||||
duration_str = '%02dm %02ds' % (minutes, seconds,)
|
||||
duration_str = '%02dm %02ds' % (minutes, seconds)
|
||||
|
||||
if self.finished:
|
||||
status = "Time spent: %s (Done!)" % (duration_str,)
|
||||
|
@ -814,16 +796,12 @@ class CursesProgress(Progress):
|
|||
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
|
||||
else:
|
||||
est_remaining_str = "Unknown"
|
||||
status = (
|
||||
"Time spent: %s (est. remaining: %s)"
|
||||
% (duration_str, est_remaining_str,)
|
||||
status = "Time spent: %s (est. remaining: %s)" % (
|
||||
duration_str,
|
||||
est_remaining_str,
|
||||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
0, 0,
|
||||
status,
|
||||
curses.A_BOLD,
|
||||
)
|
||||
self.stdscr.addstr(0, 0, status, curses.A_BOLD)
|
||||
|
||||
max_len = max([len(t) for t in self.tables.keys()])
|
||||
|
||||
|
@ -831,9 +809,7 @@ class CursesProgress(Progress):
|
|||
middle_space = 1
|
||||
|
||||
items = self.tables.items()
|
||||
items.sort(
|
||||
key=lambda i: (i[1]["perc"], i[0]),
|
||||
)
|
||||
items.sort(key=lambda i: (i[1]["perc"], i[0]))
|
||||
|
||||
for i, (table, data) in enumerate(items):
|
||||
if i + 2 >= rows:
|
||||
|
@ -844,9 +820,7 @@ class CursesProgress(Progress):
|
|||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i + 2, left_margin + max_len - len(table),
|
||||
table,
|
||||
curses.A_BOLD | color,
|
||||
i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
|
||||
)
|
||||
|
||||
size = 20
|
||||
|
@ -857,15 +831,13 @@ class CursesProgress(Progress):
|
|||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i + 2, left_margin + max_len + middle_space,
|
||||
i + 2,
|
||||
left_margin + max_len + middle_space,
|
||||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
||||
)
|
||||
|
||||
if self.finished:
|
||||
self.stdscr.addstr(
|
||||
rows - 1, 0,
|
||||
"Press any key to exit...",
|
||||
)
|
||||
self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")
|
||||
|
||||
self.stdscr.refresh()
|
||||
self.last_update = time.time()
|
||||
|
@ -877,29 +849,25 @@ class CursesProgress(Progress):
|
|||
|
||||
def set_state(self, state):
|
||||
self.stdscr.clear()
|
||||
self.stdscr.addstr(
|
||||
0, 0,
|
||||
state + "...",
|
||||
curses.A_BOLD,
|
||||
)
|
||||
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
|
||||
self.stdscr.refresh()
|
||||
|
||||
|
||||
class TerminalProgress(Progress):
|
||||
"""Just prints progress to the terminal
|
||||
"""
|
||||
|
||||
def update(self, table, num_done):
|
||||
super(TerminalProgress, self).update(table, num_done)
|
||||
|
||||
data = self.tables[table]
|
||||
|
||||
print "%s: %d%% (%d/%d)" % (
|
||||
table, data["perc"],
|
||||
data["num_done"], data["total"],
|
||||
print(
|
||||
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
|
||||
)
|
||||
|
||||
def set_state(self, state):
|
||||
print state + "..."
|
||||
print(state + "...")
|
||||
|
||||
|
||||
##############################################
|
||||
|
@ -909,34 +877,38 @@ class TerminalProgress(Progress):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A script to port an existing synapse SQLite database to"
|
||||
" a new PostgreSQL database."
|
||||
" a new PostgreSQL database."
|
||||
)
|
||||
parser.add_argument("-v", action='store_true')
|
||||
parser.add_argument(
|
||||
"--sqlite-database", required=True,
|
||||
"--sqlite-database",
|
||||
required=True,
|
||||
help="The snapshot of the SQLite database file. This must not be"
|
||||
" currently used by a running synapse server"
|
||||
" currently used by a running synapse server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--postgres-config", type=argparse.FileType('r'), required=True,
|
||||
help="The database config file for the PostgreSQL database"
|
||||
"--postgres-config",
|
||||
type=argparse.FileType('r'),
|
||||
required=True,
|
||||
help="The database config file for the PostgreSQL database",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--curses", action='store_true',
|
||||
help="display a curses based progress UI"
|
||||
"--curses", action='store_true', help="display a curses based progress UI"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=1000,
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The number of rows to select from the SQLite table each"
|
||||
" iteration [default=1000]",
|
||||
" iteration [default=1000]",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
}
|
||||
|
||||
if args.curses:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue