Make synapse._scripts pass typechecks (#12421)

This commit is contained in:
David Robertson 2022-04-08 15:00:12 +01:00 committed by GitHub
parent dd5cc37aa4
commit 0cd182f296
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 43 deletions

View file

@ -21,12 +21,13 @@ import logging
import sys
import time
import traceback
from typing import Dict, Iterable, Optional, Set
from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import Clock
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
@ -159,12 +164,14 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error = None # type: Optional[str]
end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info = None
end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None
class Store(
@ -236,9 +243,12 @@ class MockHomeserver:
return "master"
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class Porter:
def __init__(self, sqlite_config, progress, batch_size, hs_config):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
@ -323,7 +333,7 @@ class Porter(object):
"""
txn.execute(sql)
results = {}
results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@ -540,7 +550,8 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()
return store
@ -724,7 +735,9 @@ class Porter(object):
except Exception as e:
global end_error_exec_info
end_error = str(e)
end_error_exec_info = sys.exc_info()
# Type safety: we're in an exception handler, so the exc_info() tuple
# will not be (None, None, None).
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.last_update = 0.0
self.finished = False
@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
left_margin = 5
middle_space = 1
items = self.tables.items()
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@ -1179,15 +1191,11 @@ def main():
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
filename="port-synapse.log" if args.curses else None,
)
sqlite_config = {
"name": "sqlite3",
@ -1218,6 +1226,7 @@ def main():
config.parse_config_dict(hs_config, "", "")
def start(stdscr=None):
progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else: