Make synapse._scripts pass typechecks (#12421)

This commit is contained in:
David Robertson 2022-04-08 15:00:12 +01:00 committed by GitHub
parent dd5cc37aa4
commit 0cd182f296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 43 deletions

1
changelog.d/12421.misc Normal file
View File

@ -0,0 +1 @@
Make `synapse._scripts` pass type checks.

View File

@ -28,11 +28,6 @@ exclude = (?x)
|scripts-dev/federation_client.py |scripts-dev/federation_client.py
|scripts-dev/release.py |scripts-dev/release.py
|synapse/_scripts/export_signing_key.py
|synapse/_scripts/move_remote_media_to_new_store.py
|synapse/_scripts/synapse_port_db.py
|synapse/_scripts/update_synapse_database.py
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py

View File

@ -17,8 +17,8 @@ import sys
import time import time
from typing import Optional from typing import Optional
import nacl.signing
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey
def exit(status: int = 0, message: Optional[str] = None): def exit(status: int = 0, message: Optional[str] = None):
@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
sys.exit(status) sys.exit(status)
def format_plain(public_key: nacl.signing.VerifyKey): def format_plain(public_key: VerifyKey):
print( print(
"%s:%s %s" "%s:%s %s"
% ( % (
@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
) )
def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): def format_for_config(public_key: VerifyKey, expiry_ts: int):
print( print(
' "%s:%s": { key: "%s", expired_ts: %i }' ' "%s:%s": { key: "%s", expired_ts: %i }'
% ( % (

View File

@ -109,10 +109,9 @@ if __name__ == "__main__":
parser.add_argument("dest_repo", help="Path to source content repo") parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging.basicConfig(
"level": logging.DEBUG if args.v else logging.INFO, level=logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} )
logging.basicConfig(**logging_config)
main(args.src_repo, args.dest_repo) main(args.src_repo, args.dest_repo)

View File

@ -21,12 +21,13 @@ import logging
import sys import sys
import time import time
import traceback import traceback
from typing import Dict, Iterable, Optional, Set from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
import yaml import yaml
from matrix_common.versionstring import get_distribution_version_string from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import Clock from synapse.util import Clock
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db") logger = logging.getLogger("synapse_port_db")
@ -159,12 +164,14 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to # Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes. # handle errors and return codes.
end_error = None # type: Optional[str] end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script # The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but # will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run # not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace. # function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info = None end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None
class Store( class Store(
@ -236,9 +243,12 @@ class MockHomeserver:
return "master" return "master"
class Porter(object): class Porter:
def __init__(self, **kwargs): def __init__(self, sqlite_config, progress, batch_size, hs_config):
self.__dict__.update(kwargs) self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
async def setup_table(self, table): async def setup_table(self, table):
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
@ -323,7 +333,7 @@ class Porter(object):
""" """
txn.execute(sql) txn.execute(sql)
results = {} results: Dict[str, Set[str]] = {}
for table, foreign_table in txn: for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table) results.setdefault(table, set()).add(foreign_table)
return results return results
@ -540,7 +550,8 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version db_conn, allow_outdated_version=allow_outdated_version
) )
prepare_database(db_conn, engine, config=self.hs_config) prepare_database(db_conn, engine, config=self.hs_config)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit() db_conn.commit()
return store return store
@ -724,7 +735,9 @@ class Porter(object):
except Exception as e: except Exception as e:
global end_error_exec_info global end_error_exec_info
end_error = str(e) end_error = str(e)
end_error_exec_info = sys.exc_info() # Type safety: we're in an exception handler, so the exc_info() tuple
# will not be (None, None, None).
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("") logger.exception("")
finally: finally:
reactor.stop() reactor.stop()
@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1) curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1) curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0 self.last_update = 0.0
self.finished = False self.finished = False
@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
left_margin = 5 left_margin = 5
middle_space = 1 middle_space = 1
items = self.tables.items() items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items): for i, (table, data) in enumerate(items):
if i + 2 >= rows: if i + 2 >= rows:
@ -1179,15 +1191,11 @@ def main():
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging.basicConfig(
"level": logging.DEBUG if args.v else logging.INFO, level=logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} filename="port-synapse.log" if args.curses else None,
)
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
sqlite_config = { sqlite_config = {
"name": "sqlite3", "name": "sqlite3",
@ -1218,6 +1226,7 @@ def main():
config.parse_config_dict(hs_config, "", "") config.parse_config_dict(hs_config, "", "")
def start(stdscr=None): def start(stdscr=None):
progress: Progress
if stdscr: if stdscr:
progress = CursesProgress(stdscr) progress = CursesProgress(stdscr)
else: else:

View File

@ -16,22 +16,27 @@
import argparse import argparse
import logging import logging
import sys import sys
from typing import cast
import yaml import yaml
from matrix_common.versionstring import get_distribution_version_string from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor as reactor_
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.types import ISynapseReactor
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("update_database") logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer): class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__( super(MockHomeserver, self).__init__(
@ -85,12 +90,10 @@ def main():
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging.basicConfig(
"level": logging.DEBUG if args.v else logging.INFO, level=logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} )
logging.basicConfig(**logging_config)
# Load, process and sanity-check the config. # Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.database_config) hs_config = yaml.safe_load(args.database_config)