diff --git a/changelog.d/12422.misc b/changelog.d/12422.misc
new file mode 100644
index 000000000..3a7cbc34e
--- /dev/null
+++ b/changelog.d/12422.misc
@@ -0,0 +1 @@
+Make `synapse._scripts` pass type checks.
diff --git a/mypy.ini b/mypy.ini
index c11386b89..4ccea6fa5 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -93,6 +93,9 @@ exclude = (?x)
   |tests/utils.py
   )$
 
+[mypy-synapse._scripts.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.api.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py
index 66481533e..12c890bdb 100755
--- a/synapse/_scripts/export_signing_key.py
+++ b/synapse/_scripts/export_signing_key.py
@@ -15,19 +15,19 @@
 import argparse
 import sys
 import time
-from typing import Optional
+from typing import NoReturn, Optional
 
 from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
 from signedjson.types import VerifyKey
 
 
-def exit(status: int = 0, message: Optional[str] = None):
+def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
     if message:
         print(message, file=sys.stderr)
     sys.exit(status)
 
 
-def format_plain(public_key: VerifyKey):
+def format_plain(public_key: VerifyKey) -> None:
     print(
         "%s:%s %s"
         % (
@@ -38,7 +38,7 @@ def format_plain(public_key: VerifyKey):
     )
 
 
-def format_for_config(public_key: VerifyKey, expiry_ts: int):
+def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
     print(
         '  "%s:%s": { key: "%s", expired_ts: %i }'
         % (
@@ -50,7 +50,7 @@ def format_for_config(public_key: VerifyKey, expiry_ts: int):
     )
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -94,7 +94,6 @@ def main():
                 message="Error reading key from file %s: %s %s"
                 % (file.name, type(e), e),
             )
-            res = []
 
         for key in res:
             formatter(get_verify_key(key))
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
index 75fce20b1..08eb8ef11 100755
--- a/synapse/_scripts/generate_config.py
+++ b/synapse/_scripts/generate_config.py
@@ -7,7 +7,7 @@ import sys
 from synapse.config.homeserver import HomeServerConfig
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config-dir",
diff --git a/synapse/_scripts/generate_log_config.py b/synapse/_scripts/generate_log_config.py
index 82fc76314..7ae08ec0e 100755
--- a/synapse/_scripts/generate_log_config.py
+++ b/synapse/_scripts/generate_log_config.py
@@ -20,7 +20,7 @@ import sys
 from synapse.config.logger import DEFAULT_LOG_CONFIG
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
diff --git a/synapse/_scripts/generate_signing_key.py b/synapse/_scripts/generate_signing_key.py
index bc26d25bf..3f8f5da75 100755
--- a/synapse/_scripts/generate_signing_key.py
+++ b/synapse/_scripts/generate_signing_key.py
@@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
 from synapse.util.stringutils import random_string
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 708640c7d..3aa29de5b 100755
--- a/synapse/_scripts/hash_password.py
+++ b/synapse/_scripts/hash_password.py
@@ -9,7 +9,7 @@ import bcrypt
 import yaml
 
 
-def prompt_for_pass():
+def prompt_for_pass() -> str:
     password = getpass.getpass("Password: ")
 
     if not password:
@@ -23,7 +23,7 @@ def prompt_for_pass():
     return password
 
 
-def main():
+def main() -> None:
     bcrypt_rounds = 12
     password_pepper = ""
 
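The `exit()` change in `export_signing_key.py` above is the standard `NoReturn` pattern: annotating a helper that always calls `sys.exit()` (or raises) tells mypy that code after a call to it is unreachable, which is also why the now-dead `res = []` fallback could be deleted. A minimal standalone sketch of the idea — illustrative names, not part of the patch:

```python
import sys
from typing import NoReturn, Optional


def bail(status: int = 1, message: Optional[str] = None) -> NoReturn:
    # Always exits, so mypy treats anything after a call as unreachable.
    if message:
        print(message, file=sys.stderr)
    sys.exit(status)


def parse(value: str) -> int:
    if not value.isdigit():
        bail(message="not a number: %s" % (value,))
    # mypy knows we only reach this line when the check passed.
    return int(value)


print(parse("12"))  # 12
```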
diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index f53bf790a..819afaaca 100755
--- a/synapse/_scripts/move_remote_media_to_new_store.py
+++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 logger = logging.getLogger()
 
 
-def main(src_repo, dest_repo):
+def main(src_repo: str, dest_repo: str) -> None:
     src_paths = MediaFilePaths(src_repo)
     dest_paths = MediaFilePaths(dest_repo)
     for line in sys.stdin:
@@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
         move_media(parts[0], parts[1], src_paths, dest_paths)
 
 
-def move_media(origin_server, file_id, src_paths, dest_paths):
+def move_media(
+    origin_server: str,
+    file_id: str,
+    src_paths: MediaFilePaths,
+    dest_paths: MediaFilePaths,
+) -> None:
     """Move the given file, and any thumbnails, to the dest repo
 
     Args:
-        origin_server (str):
-        file_id (str):
-        src_paths (MediaFilePaths):
-        dest_paths (MediaFilePaths):
+        origin_server:
+        file_id:
+        src_paths:
+        dest_paths:
     """
     logger.info("%s/%s", origin_server, file_id)
 
@@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
     )
 
 
-def mkdir_and_move(original_file, dest_file):
+def mkdir_and_move(original_file: str, dest_file: str) -> None:
     dirname = os.path.dirname(dest_file)
     if not os.path.exists(dirname):
         logger.debug("mkdir %s", dirname)
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 4ffe6a1ef..092601f53 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -22,7 +22,7 @@ import logging
 import sys
 from typing import Callable, Optional
 
-import requests as _requests
+import requests
 import yaml
 
 
@@ -33,7 +33,6 @@ def request_registration(
     shared_secret: str,
     admin: bool = False,
     user_type: Optional[str] = None,
-    requests=_requests,
     _print: Callable[[str], None] = print,
     exit: Callable[[int], None] = sys.exit,
 ) -> None:
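The `move_media` hunk above shows the usual migration from docstring types to real annotations: the types move into the signature, where mypy can check callers, while the `Args:` section keeps only the (here empty) descriptions. A hypothetical, simplified version of the same idea — not the real Synapse function:

```python
from typing import List


def move_media(origin_server: str, file_ids: List[str]) -> None:
    """Move the given files to the destination repo.

    Args:
        origin_server: server the media originated from
        file_ids: media IDs to move
    """
    # Types now live in the signature; the docstring keeps only
    # the human-readable descriptions.
    for file_id in file_ids:
        print(origin_server, file_id)


move_media("example.com", ["abc123"])
```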
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 123eaae5c..12ff79f6e 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -22,10 +22,26 @@ import sys
 import time
 import traceback
 from types import TracebackType
-from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    NoReturn,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import yaml
 from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import TypedDict
 
 from twisted.internet import defer, reactor as reactor_
 
@@ -36,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import PushRuleStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -173,6 +189,8 @@ end_error_exec_info: Optional[
     Tuple[Type[BaseException], BaseException, TracebackType]
 ] = None
 
+R = TypeVar("R")
+
 
 class Store(
     ClientIpBackgroundUpdateStore,
@@ -195,17 +213,19 @@ class Store(
     PresenceBackgroundUpdateStore,
     GroupServerWorkerStore,
 ):
-    def execute(self, f, *args, **kwargs):
+    def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
 
-    def execute_sql(self, sql, *args):
-        def r(txn):
+    def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
+        def r(txn: LoggingTransaction) -> List[Tuple]:
             txn.execute(sql, args)
             return txn.fetchall()
 
         return self.db_pool.runInteraction("execute_sql", r)
 
-    def insert_many_txn(self, txn, table, headers, rows):
+    def insert_many_txn(
+        self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
+    ) -> None:
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in headers),
@@ -218,14 +238,15 @@ class Store(
             logger.exception("Failed to insert: %s", table)
             raise
 
-    def set_room_is_public(self, room_id, is_public):
+    # Note: the parent method is an `async def`.
+    def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
         raise Exception(
             "Attempt to set room_is_public during port_db: database not empty?"
         )
 
 
 class MockHomeserver:
-    def __init__(self, config):
+    def __init__(self, config: HomeServerConfig):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server.server_name
@@ -233,24 +254,30 @@ class MockHomeserver:
             "matrix-synapse"
         )
 
-    def get_clock(self):
+    def get_clock(self) -> Clock:
         return self.clock
 
-    def get_reactor(self):
+    def get_reactor(self) -> ISynapseReactor:
         return reactor
 
-    def get_instance_name(self):
+    def get_instance_name(self) -> str:
         return "master"
 
 
 class Porter:
-    def __init__(self, sqlite_config, progress, batch_size, hs_config):
+    def __init__(
+        self,
+        sqlite_config: Dict[str, Any],
+        progress: "Progress",
+        batch_size: int,
+        hs_config: HomeServerConfig,
+    ):
         self.sqlite_config = sqlite_config
         self.progress = progress
         self.batch_size = batch_size
         self.hs_config = hs_config
 
-    async def setup_table(self, table):
+    async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
             row = await self.postgres_store.db_pool.simple_select_one(
@@ -292,7 +319,7 @@ class Porter:
             )
         else:
 
-            def delete_all(txn):
+            def delete_all(txn: LoggingTransaction) -> None:
                 txn.execute(
                     "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
                 )
@@ -317,7 +344,7 @@ class Porter:
     async def get_table_constraints(self) -> Dict[str, Set[str]]:
         """Returns a map of tables that have foreign key constraints to tables they depend on."""
 
-        def _get_constraints(txn):
+        def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
             # We can pull the information about foreign key constraints out from
             # the postgres schema tables.
             sql = """
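The new `Store.execute` signature above uses a `TypeVar` so the wrapper preserves its callback's return type instead of collapsing it to `Any`. A self-contained sketch of the same pattern — the `_run_interaction` stand-in is illustrative, not a Synapse API:

```python
import asyncio
from typing import Any, Awaitable, Callable, TypeVar

R = TypeVar("R")


async def _run_interaction(f: Callable[..., R], *args: Any, **kwargs: Any) -> R:
    # Stand-in for DatabasePool.runInteraction: call f and hand back its result.
    return f(*args, **kwargs)


def execute(f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
    # The TypeVar threads the callback's return type through the wrapper,
    # so `await execute(len, "abc")` is an int to mypy, not Any.
    return _run_interaction(f, *args, **kwargs)


print(asyncio.run(execute(len, "abc")))  # 3
```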
sql = """ @@ -343,8 +370,13 @@ class Porter: ) async def handle_table( - self, table, postgres_size, table_size, forward_chunk, backward_chunk - ): + self, + table: str, + postgres_size: int, + table_size: int, + forward_chunk: int, + backward_chunk: int, + ) -> None: logger.info( "Table %s: %i/%i (rows %i-%i) already ported", table, @@ -391,7 +423,9 @@ class Porter: while True: - def r(txn): + def r( + txn: LoggingTransaction, + ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]: forward_rows = [] backward_rows = [] if do_forward[0]: @@ -418,6 +452,7 @@ class Porter: ) if frows or brows: + assert headers is not None if frows: forward_chunk = max(row[0] for row in frows) + 1 if brows: @@ -426,7 +461,8 @@ class Porter: rows = frows + brows rows = self._convert_rows(table, headers, rows) - def insert(txn): + def insert(txn: LoggingTransaction) -> None: + assert headers is not None self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store.db_pool.simple_update_one_txn( @@ -448,8 +484,12 @@ class Porter: return async def handle_search_table( - self, postgres_size, table_size, forward_chunk, backward_chunk - ): + self, + postgres_size: int, + table_size: int, + forward_chunk: int, + backward_chunk: int, + ) -> None: select = ( "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" " FROM event_search as es" @@ -460,7 +500,7 @@ class Porter: while True: - def r(txn): + def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() headers = [column[0] for column in txn.description] @@ -474,7 +514,7 @@ class Porter: # We have to treat event_search differently since it has a # different structure in the two different databases. - def insert(txn): + def insert(txn: LoggingTransaction) -> None: sql = ( "INSERT INTO event_search (event_id, room_id, key," " sender, vector, origin_server_ts, stream_ordering)" @@ -528,7 +568,7 @@ class Porter: self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False, - ): + ) -> Store: """Builds and returns a database store using the provided configuration. Args: @@ -556,7 +596,7 @@ class Porter: return store - async def run_background_updates_on_postgres(self): + async def run_background_updates_on_postgres(self) -> None: # Manually apply all background updates on the PostgreSQL database. postgres_ready = ( await self.postgres_store.db_pool.updates.has_completed_background_updates() @@ -568,12 +608,12 @@ class Porter: self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - await self.postgres_store.db_pool.updates.do_next_background_update(100) + await self.postgres_store.db_pool.updates.do_next_background_update(True) postgres_ready = await ( self.postgres_store.db_pool.updates.has_completed_background_updates() ) - async def run(self): + async def run(self) -> None: """Ports the SQLite database to a PostgreSQL database. When a fatal error is met, its message is assigned to the global "end_error" @@ -609,7 +649,7 @@ class Porter: self.progress.set_state("Creating port tables") - def create_port_table(txn): + def create_port_table(txn: LoggingTransaction) -> None: txn.execute( "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" " table_name varchar(100) NOT NULL UNIQUE," @@ -622,7 +662,7 @@ class Porter: # We want people to be able to rerun this script from an old port # so that they can pick up any missing events that were not # ported across. 
@@ -742,7 +782,9 @@ class Porter:
         finally:
             reactor.stop()
 
-    def _convert_rows(self, table, headers, rows):
+    def _convert_rows(
+        self, table: str, headers: List[str], rows: List[Tuple]
+    ) -> List[Tuple]:
         bool_col_names = BOOLEAN_COLUMNS.get(table, [])
 
         bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@@ -750,7 +792,7 @@
         class BadValueException(Exception):
             pass
 
-        def conv(j, col):
+        def conv(j: int, col: object) -> object:
             if j in bool_cols:
                 return bool(col)
             if isinstance(col, bytes):
@@ -776,7 +818,7 @@
 
         return outrows
 
-    async def _setup_sent_transactions(self):
+    async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
         # Only save things from the last day
         yesterday = int(time.time() * 1000) - 86400000
 
@@ -788,10 +830,10 @@
             ")"
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
             txn.execute(select)
             rows = txn.fetchall()
-            headers = [column[0] for column in txn.description]
+            headers: List[str] = [column[0] for column in txn.description]
 
             ts_ind = headers.index("ts")
 
@@ -805,7 +847,7 @@
         if inserted_rows:
             max_inserted_rowid = max(r[0] for r in rows)
 
-            def insert(txn):
+            def insert(txn: LoggingTransaction) -> None:
                 self.postgres_store.insert_many_txn(
                     txn, "sent_transactions", headers[1:], rows
                 )
@@ -814,7 +856,7 @@
         else:
             max_inserted_rowid = 0
 
-        def get_start_id(txn):
+        def get_start_id(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                 " ORDER BY rowid ASC LIMIT 1",
@@ -839,12 +881,13 @@
             },
         )
 
-        def get_sent_table_size(txn):
+        def get_sent_table_size(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            (size,) = txn.fetchone()
-            return int(size)
+            result = txn.fetchone()
+            assert result is not None
+            return int(result[0])
 
         remaining_count = await self.sqlite_store.execute(get_sent_table_size)
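The `conv(j: int, col: object) -> object` signature above deliberately uses `object` rather than `Any`: `object` accepts every value but forbids all operations on it until the type is narrowed, so mistakes are caught instead of silenced. A standalone sketch of the difference:

```python
def describe(value: object) -> str:
    # With `object`, mypy rejects value.upper() here outright; you must
    # narrow first. With `Any`, a typo like value.uppr() would go unchecked.
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    if isinstance(value, str):
        return value
    return repr(value)


print(describe(b"caf\xc3\xa9"), describe(3.14))  # café 3.14
```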
@@ -852,25 +895,35 @@
 
         return next_chunk, inserted_rows, total_count
 
-    async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+    async def _get_remaining_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> int:
+        frows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+            ),
         )
 
-        brows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+        brows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+            ),
         )
 
         return frows[0][0] + brows[0][0]
 
-    async def _get_already_ported_count(self, table):
+    async def _get_already_ported_count(self, table: str) -> int:
         rows = await self.postgres_store.execute_sql(
             "SELECT count(*) FROM %s" % (table,)
         )
 
         return rows[0][0]
 
-    async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    async def _get_total_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> Tuple[int, int]:
         remaining, done = await make_deferred_yieldable(
             defer.gatherResults(
                 [
@@ -891,14 +944,17 @@ class Porter:
 
         return done, remaining + done
 
     async def _setup_state_group_id_seq(self) -> None:
-        curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
         )
 
         if not curr_id:
             return
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            assert curr_id is not None
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
@@ -909,7 +965,7 @@ class Porter:
             "setup_user_id_seq", find_max_generated_user_id_localpart
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
 
@@ -931,7 +987,7 @@ class Porter:
             allow_none=True,
         )
 
-        def _setup_events_stream_seqs_set_pos(txn):
+        def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
             if curr_forward_id:
                 txn.execute(
                     "ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@@ -955,17 +1011,20 @@
         """Set a sequence to the correct value."""
         current_stream_ids = []
         for stream_id_table in stream_id_tables:
-            max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
-                table=stream_id_table,
-                keyvalues={},
-                retcol="COALESCE(MAX(stream_id), 1)",
-                allow_none=True,
+            max_stream_id = cast(
+                int,
+                await self.sqlite_store.db_pool.simple_select_one_onecol(
+                    table=stream_id_table,
+                    keyvalues={},
+                    retcol="COALESCE(MAX(stream_id), 1)",
+                    allow_none=True,
+                ),
             )
             current_stream_ids.append(max_stream_id)
 
         next_id = max(current_stream_ids) + 1
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
             txn.execute(sql + " %s", (next_id,))
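The `cast(...)` calls above assert a type to mypy without any runtime check, which is appropriate when the query shape guarantees the result — `SELECT count(*)` always yields exactly one single-integer row. A runnable standalone sketch using the stdlib `sqlite3` driver (not the Synapse storage layer):

```python
import sqlite3
from typing import List, Tuple, cast


def count_rows(conn: sqlite3.Connection, table: str) -> int:
    cur = conn.execute("SELECT count(*) FROM %s" % (table,))
    # fetchall() is only loosely typed by the driver; the query shape
    # guarantees exactly one (int,) row, and cast() tells mypy so.
    # cast() has no runtime effect - it is purely a type-checker hint.
    rows = cast(List[Tuple[int]], cur.fetchall())
    return rows[0][0]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x INTEGER)")
conn.execute("INSERT INTO t VALUES (1)")
print(count_rows(conn, "t"))  # 1
```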
@@ -974,14 +1033,18 @@ class Porter:
         )
 
     async def _setup_auth_chain_sequence(self) -> None:
-        curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_chain_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="event_auth_chains",
             keyvalues={},
             retcol="MAX(chain_id)",
             allow_none=True,
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            # Presumably there is at least one row in event_auth_chains.
+            assert curr_chain_id is not None
             txn.execute(
                 "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
                 (curr_chain_id + 1,),
@@ -999,15 +1062,22 @@ class Porter:
 ##############################################
 
 
-class Progress(object):
+class TableProgress(TypedDict):
+    start: int
+    num_done: int
+    total: int
+    perc: int
+
+
+class Progress:
     """Used to report progress of the port"""
 
-    def __init__(self):
-        self.tables = {}
+    def __init__(self) -> None:
+        self.tables: Dict[str, TableProgress] = {}
         self.start_time = int(time.time())
 
-    def add_table(self, table, cur, size):
+    def add_table(self, table: str, cur: int, size: int) -> None:
         self.tables[table] = {
             "start": cur,
             "num_done": cur,
@@ -1015,19 +1085,22 @@ class Progress:
             "perc": int(cur * 100 / size),
         }
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         data = self.tables[table]
         data["num_done"] = num_done
         data["perc"] = int(num_done * 100 / data["total"])
 
-    def done(self):
+    def done(self) -> None:
+        pass
+
+    def set_state(self, state: str) -> None:
         pass
 
 
 class CursesProgress(Progress):
     """Reports progress to a curses window"""
 
-    def __init__(self, stdscr):
+    def __init__(self, stdscr: "curses.window"):
         self.stdscr = stdscr
 
         curses.use_default_colors()
@@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
 
         super(CursesProgress, self).__init__()
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(CursesProgress, self).update(table, num_done)
 
         self.total_processed = 0
@@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
 
         self.render()
 
-    def render(self, force=False):
+    def render(self, force: bool = False) -> None:
         now = time.time()
 
         if not force and now - self.last_update < 0.2:
@@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
 
         self.stdscr.refresh()
         self.last_update = time.time()
 
-    def done(self):
+    def done(self) -> None:
         self.finished = True
         self.render(True)
         self.stdscr.getch()
 
-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         self.stdscr.clear()
         self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
         self.stdscr.refresh()
@@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
 class TerminalProgress(Progress):
     """Just prints progress to the terminal"""
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(TerminalProgress, self).update(table, num_done)
 
         data = self.tables[table]
@@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
             "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
         )
 
-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         print(state + "...")
 
 
@@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
 ##############################################
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."
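`TableProgress` above replaces an untyped dict-of-dicts: a `TypedDict` keeps the runtime object a plain `dict` while giving each key a checked value type. A standalone sketch:

```python
from typing_extensions import TypedDict  # `typing.TypedDict` on Python 3.8+


class TableProgress(TypedDict):
    start: int
    num_done: int
    total: int
    perc: int


def make_progress(cur: int, size: int) -> TableProgress:
    # Still an ordinary dict at runtime, but mypy now rejects missing keys,
    # unknown keys, and wrongly-typed values at each literal like this one.
    return {"start": cur, "num_done": cur, "total": size, "perc": cur * 100 // size}


print(make_progress(5, 20)["perc"])  # 25
```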
@@ -1225,7 +1298,7 @@ def main():
     config = HomeServerConfig()
     config.parse_config_dict(hs_config, "", "")
 
-    def start(stdscr=None):
+    def start(stdscr: Optional["curses.window"] = None) -> None:
         progress: Progress
         if stdscr:
             progress = CursesProgress(stdscr)
@@ -1240,7 +1313,7 @@ def main():
         )
 
         @defer.inlineCallbacks
-        def run():
+        def run() -> Generator["defer.Deferred[Any]", Any, None]:
             with LoggingContext("synapse_port_db_run"):
                 yield defer.ensureDeferred(porter.run())
diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
index 1ab36949c..b4c96ad7f 100755
--- a/synapse/_scripts/synctl.py
+++ b/synapse/_scripts/synctl.py
@@ -24,7 +24,7 @@ import signal
 import subprocess
 import sys
 import time
-from typing import Iterable, Optional
+from typing import Iterable, NoReturn, Optional, TextIO
 
 import yaml
 
@@ -45,7 +45,7 @@ one of the following:
 --------------------------------------------------------------------------------"""
 
 
-def pid_running(pid):
+def pid_running(pid: int) -> bool:
     try:
         os.kill(pid, 0)
     except OSError as err:
@@ -68,7 +68,7 @@ def pid_running(pid):
     return True
 
 
-def write(message, colour=NORMAL, stream=sys.stdout):
+def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
     # Lets check if we're writing to a TTY before colouring
     should_colour = False
     try:
@@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
     stream.write(colour + message + NORMAL + "\n")
 
 
-def abort(message, colour=RED, stream=sys.stderr):
+def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
     write(message, colour, stream)
     sys.exit(1)
 
@@ -166,7 +166,7 @@ Worker = collections.namedtuple(
 )
 
 
-def main():
+def main() -> None:
 
     parser = argparse.ArgumentParser()
 
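`write`/`abort` in `synctl.py` above type their stream parameter as `TextIO`, which covers `sys.stdout`/`sys.stderr` as well as in-memory text buffers, keeping the helpers easy to test. A minimal sketch of the idea:

```python
import io
import sys
from typing import TextIO


def write(message: str, stream: TextIO = sys.stdout) -> None:
    # TextIO matches any text-mode file object, so tests can pass a StringIO.
    stream.write(message + "\n")


buf = io.StringIO()
write("hello", stream=buf)
assert buf.getvalue() == "hello\n"
```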
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index 736f58836..c443522c0 100755
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -38,25 +38,25 @@ logger = logging.getLogger("update_database")
 class MockHomeserver(HomeServer):
     DATASTORE_CLASS = DataStore  # type: ignore [assignment]
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config: HomeServerConfig):
         super(MockHomeserver, self).__init__(
-            config.server.server_name, reactor=reactor, config=config, **kwargs
-        )
-
-        self.version_string = "Synapse/" + get_distribution_version_string(
-            "matrix-synapse"
+            hostname=config.server.server_name,
+            config=config,
+            reactor=reactor,
+            version_string="Synapse/"
+            + get_distribution_version_string("matrix-synapse"),
         )
 
 
-def run_background_updates(hs):
+def run_background_updates(hs: HomeServer) -> None:
     store = hs.get_datastores().main
 
-    async def run_background_updates():
+    async def run_background_updates() -> None:
         await store.db_pool.updates.run_background_updates(sleep=False)
         # Stop the reactor to exit the script once every background update is run.
         reactor.stop()
 
-    def run():
+    def run() -> None:
         # Apply all background updates on the database.
         defer.ensureDeferred(
             run_as_background_process("background_updates", run_background_updates)
@@ -67,7 +67,7 @@ def run_background_updates(hs):
     reactor.run()
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description=(
             "Updates a synapse database to the latest schema and optionally runs background updates"
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 57f4883bf..d7d6f1d90 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -45,6 +45,7 @@ class Cursor(Protocol):
         Sequence[
             # Note that this is an approximate typing based on sqlite3 and other
             # drivers, and may not be entirely accurate.
+            # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
             Tuple[
                 str,
                 Optional[Any],
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 6f3c365c9..19a145eeb 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 from synapse._scripts.register_new_matrix_user import request_registration
 
@@ -52,16 +52,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # We should get the success message making sure everything is OK.
         self.assertIn("Success!", out)
@@ -88,16 +88,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # Exit was called
         self.assertEqual(err_code, [1])
@@ -140,16 +140,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # Exit was called
         self.assertEqual(err_code, [1])
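The test changes above swap dependency injection (the removed `requests=_requests` parameter) for `unittest.mock.patch`, which replaces the module-level `requests` attribute for the duration of the `with` block and restores it afterwards, even on exceptions. A standalone illustration of the mechanism, using the stdlib `json` module as a stand-in target rather than the Synapse module:

```python
import json as target_module  # stand-in for synapse._scripts.register_new_matrix_user
from unittest.mock import Mock, patch

fake = Mock()
with patch("json.dumps", fake):
    # Inside the block the module attribute is the Mock...
    target_module.dumps({"a": 1})
fake.assert_called_once_with({"a": 1})
# ...and it is restored afterwards, even if the block raised.
assert target_module.dumps({"a": 1}) == '{"a": 1}'
```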