mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Disallow untyped defs in synapse._scripts (#12422)
Of note: * No untyped defs in `register_new_matrix_user` This one might be contraversial. `request_registration` has three dependency-injection arguments used for testing. I'm removing the injection of the `requests` module and using `unitest.mock.patch` in the test cases instead. Doing `reveal_type(requests)` and `reveal_type(requests.get)` before the change: ``` synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "Any" synapse/_scripts/register_new_matrix_user.py:46: note: Revealed type is "Any" ``` And after: ``` synapse/_scripts/register_new_matrix_user.py:44: note: Revealed type is "types.ModuleType" synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "def (url: Union[builtins.str, builtins.bytes], params: Union[Union[_typeshed.SupportsItems[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], typing.Iterable[Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]]], builtins.str, builtins.bytes], None] =, data: Union[Any, None] =, headers: Union[Any, None] =, cookies: Union[Any, None] =, files: Union[Any, None] =, auth: Union[Any, None] =, timeout: Union[Any, None] =, allow_redirects: builtins.bool =, proxies: Union[Any, None] =, hooks: Union[Any, None] =, stream: Union[Any, None] =, verify: Union[Any, None] =, cert: Union[Any, None] =, json: Union[Any, None] =) -> requests.models.Response" ``` * Drive-by comment in `synapse.storage.types` * No untyped defs in `synapse_port_db` This was by far the most painful. I'm happy to break this up into smaller pieces for review if it's not managable as-is.
This commit is contained in:
parent
5f72ea1bde
commit
961ee75a9b
1
changelog.d/12422.misc
Normal file
1
changelog.d/12422.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make `synapse._scripts` pass type checks.
|
3
mypy.ini
3
mypy.ini
@ -93,6 +93,9 @@ exclude = (?x)
|
|||||||
|tests/utils.py
|
|tests/utils.py
|
||||||
)$
|
)$
|
||||||
|
|
||||||
|
[mypy-synapse._scripts.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.api.*]
|
[mypy-synapse.api.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -15,19 +15,19 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import NoReturn, Optional
|
||||||
|
|
||||||
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
|
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
|
||||||
from signedjson.types import VerifyKey
|
from signedjson.types import VerifyKey
|
||||||
|
|
||||||
|
|
||||||
def exit(status: int = 0, message: Optional[str] = None):
|
def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
|
||||||
if message:
|
if message:
|
||||||
print(message, file=sys.stderr)
|
print(message, file=sys.stderr)
|
||||||
sys.exit(status)
|
sys.exit(status)
|
||||||
|
|
||||||
|
|
||||||
def format_plain(public_key: VerifyKey):
|
def format_plain(public_key: VerifyKey) -> None:
|
||||||
print(
|
print(
|
||||||
"%s:%s %s"
|
"%s:%s %s"
|
||||||
% (
|
% (
|
||||||
@ -38,7 +38,7 @@ def format_plain(public_key: VerifyKey):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_for_config(public_key: VerifyKey, expiry_ts: int):
|
def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
|
||||||
print(
|
print(
|
||||||
' "%s:%s": { key: "%s", expired_ts: %i }'
|
' "%s:%s": { key: "%s", expired_ts: %i }'
|
||||||
% (
|
% (
|
||||||
@ -50,7 +50,7 @@ def format_for_config(public_key: VerifyKey, expiry_ts: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -94,7 +94,6 @@ def main():
|
|||||||
message="Error reading key from file %s: %s %s"
|
message="Error reading key from file %s: %s %s"
|
||||||
% (file.name, type(e), e),
|
% (file.name, type(e), e),
|
||||||
)
|
)
|
||||||
res = []
|
|
||||||
for key in res:
|
for key in res:
|
||||||
formatter(get_verify_key(key))
|
formatter(get_verify_key(key))
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import sys
|
|||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config-dir",
|
"--config-dir",
|
||||||
|
@ -20,7 +20,7 @@ import sys
|
|||||||
from synapse.config.logger import DEFAULT_LOG_CONFIG
|
from synapse.config.logger import DEFAULT_LOG_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
|
|||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -9,7 +9,7 @@ import bcrypt
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def prompt_for_pass():
|
def prompt_for_pass() -> str:
|
||||||
password = getpass.getpass("Password: ")
|
password = getpass.getpass("Password: ")
|
||||||
|
|
||||||
if not password:
|
if not password:
|
||||||
@ -23,7 +23,7 @@ def prompt_for_pass():
|
|||||||
return password
|
return password
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
bcrypt_rounds = 12
|
bcrypt_rounds = 12
|
||||||
password_pepper = ""
|
password_pepper = ""
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def main(src_repo, dest_repo):
|
def main(src_repo: str, dest_repo: str) -> None:
|
||||||
src_paths = MediaFilePaths(src_repo)
|
src_paths = MediaFilePaths(src_repo)
|
||||||
dest_paths = MediaFilePaths(dest_repo)
|
dest_paths = MediaFilePaths(dest_repo)
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
|
|||||||
move_media(parts[0], parts[1], src_paths, dest_paths)
|
move_media(parts[0], parts[1], src_paths, dest_paths)
|
||||||
|
|
||||||
|
|
||||||
def move_media(origin_server, file_id, src_paths, dest_paths):
|
def move_media(
|
||||||
|
origin_server: str,
|
||||||
|
file_id: str,
|
||||||
|
src_paths: MediaFilePaths,
|
||||||
|
dest_paths: MediaFilePaths,
|
||||||
|
) -> None:
|
||||||
"""Move the given file, and any thumbnails, to the dest repo
|
"""Move the given file, and any thumbnails, to the dest repo
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin_server (str):
|
origin_server:
|
||||||
file_id (str):
|
file_id:
|
||||||
src_paths (MediaFilePaths):
|
src_paths:
|
||||||
dest_paths (MediaFilePaths):
|
dest_paths:
|
||||||
"""
|
"""
|
||||||
logger.info("%s/%s", origin_server, file_id)
|
logger.info("%s/%s", origin_server, file_id)
|
||||||
|
|
||||||
@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mkdir_and_move(original_file, dest_file):
|
def mkdir_and_move(original_file: str, dest_file: str) -> None:
|
||||||
dirname = os.path.dirname(dest_file)
|
dirname = os.path.dirname(dest_file)
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
logger.debug("mkdir %s", dirname)
|
logger.debug("mkdir %s", dirname)
|
||||||
|
@ -22,7 +22,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import requests as _requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +33,6 @@ def request_registration(
|
|||||||
shared_secret: str,
|
shared_secret: str,
|
||||||
admin: bool = False,
|
admin: bool = False,
|
||||||
user_type: Optional[str] = None,
|
user_type: Optional[str] = None,
|
||||||
requests=_requests,
|
|
||||||
_print: Callable[[str], None] = print,
|
_print: Callable[[str], None] = print,
|
||||||
exit: Callable[[int], None] = sys.exit,
|
exit: Callable[[int], None] = sys.exit,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -22,10 +22,26 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
NoReturn,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from matrix_common.versionstring import get_distribution_version_string
|
from matrix_common.versionstring import get_distribution_version_string
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from twisted.internet import defer, reactor as reactor_
|
from twisted.internet import defer, reactor as reactor_
|
||||||
|
|
||||||
@ -36,7 +52,7 @@ from synapse.logging.context import (
|
|||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
run_in_background,
|
run_in_background,
|
||||||
)
|
)
|
||||||
from synapse.storage.database import DatabasePool, make_conn
|
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
|
||||||
from synapse.storage.databases.main import PushRuleStore
|
from synapse.storage.databases.main import PushRuleStore
|
||||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||||
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
||||||
@ -173,6 +189,8 @@ end_error_exec_info: Optional[
|
|||||||
Tuple[Type[BaseException], BaseException, TracebackType]
|
Tuple[Type[BaseException], BaseException, TracebackType]
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class Store(
|
class Store(
|
||||||
ClientIpBackgroundUpdateStore,
|
ClientIpBackgroundUpdateStore,
|
||||||
@ -195,17 +213,19 @@ class Store(
|
|||||||
PresenceBackgroundUpdateStore,
|
PresenceBackgroundUpdateStore,
|
||||||
GroupServerWorkerStore,
|
GroupServerWorkerStore,
|
||||||
):
|
):
|
||||||
def execute(self, f, *args, **kwargs):
|
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
||||||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
||||||
|
|
||||||
def execute_sql(self, sql, *args):
|
def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> List[Tuple]:
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
return self.db_pool.runInteraction("execute_sql", r)
|
return self.db_pool.runInteraction("execute_sql", r)
|
||||||
|
|
||||||
def insert_many_txn(self, txn, table, headers, rows):
|
def insert_many_txn(
|
||||||
|
self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
|
||||||
|
) -> None:
|
||||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||||
table,
|
table,
|
||||||
", ".join(k for k in headers),
|
", ".join(k for k in headers),
|
||||||
@ -218,14 +238,15 @@ class Store(
|
|||||||
logger.exception("Failed to insert: %s", table)
|
logger.exception("Failed to insert: %s", table)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def set_room_is_public(self, room_id, is_public):
|
# Note: the parent method is an `async def`.
|
||||||
|
def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Attempt to set room_is_public during port_db: database not empty?"
|
"Attempt to set room_is_public during port_db: database not empty?"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockHomeserver:
|
class MockHomeserver:
|
||||||
def __init__(self, config):
|
def __init__(self, config: HomeServerConfig):
|
||||||
self.clock = Clock(reactor)
|
self.clock = Clock(reactor)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hostname = config.server.server_name
|
self.hostname = config.server.server_name
|
||||||
@ -233,24 +254,30 @@ class MockHomeserver:
|
|||||||
"matrix-synapse"
|
"matrix-synapse"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_clock(self):
|
def get_clock(self) -> Clock:
|
||||||
return self.clock
|
return self.clock
|
||||||
|
|
||||||
def get_reactor(self):
|
def get_reactor(self) -> ISynapseReactor:
|
||||||
return reactor
|
return reactor
|
||||||
|
|
||||||
def get_instance_name(self):
|
def get_instance_name(self) -> str:
|
||||||
return "master"
|
return "master"
|
||||||
|
|
||||||
|
|
||||||
class Porter:
|
class Porter:
|
||||||
def __init__(self, sqlite_config, progress, batch_size, hs_config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
sqlite_config: Dict[str, Any],
|
||||||
|
progress: "Progress",
|
||||||
|
batch_size: int,
|
||||||
|
hs_config: HomeServerConfig,
|
||||||
|
):
|
||||||
self.sqlite_config = sqlite_config
|
self.sqlite_config = sqlite_config
|
||||||
self.progress = progress
|
self.progress = progress
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.hs_config = hs_config
|
self.hs_config = hs_config
|
||||||
|
|
||||||
async def setup_table(self, table):
|
async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
|
||||||
if table in APPEND_ONLY_TABLES:
|
if table in APPEND_ONLY_TABLES:
|
||||||
# It's safe to just carry on inserting.
|
# It's safe to just carry on inserting.
|
||||||
row = await self.postgres_store.db_pool.simple_select_one(
|
row = await self.postgres_store.db_pool.simple_select_one(
|
||||||
@ -292,7 +319,7 @@ class Porter:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def delete_all(txn):
|
def delete_all(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
|
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
|
||||||
)
|
)
|
||||||
@ -317,7 +344,7 @@ class Porter:
|
|||||||
async def get_table_constraints(self) -> Dict[str, Set[str]]:
|
async def get_table_constraints(self) -> Dict[str, Set[str]]:
|
||||||
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
|
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
|
||||||
|
|
||||||
def _get_constraints(txn):
|
def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
|
||||||
# We can pull the information about foreign key constraints out from
|
# We can pull the information about foreign key constraints out from
|
||||||
# the postgres schema tables.
|
# the postgres schema tables.
|
||||||
sql = """
|
sql = """
|
||||||
@ -343,8 +370,13 @@ class Porter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def handle_table(
|
async def handle_table(
|
||||||
self, table, postgres_size, table_size, forward_chunk, backward_chunk
|
self,
|
||||||
):
|
table: str,
|
||||||
|
postgres_size: int,
|
||||||
|
table_size: int,
|
||||||
|
forward_chunk: int,
|
||||||
|
backward_chunk: int,
|
||||||
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Table %s: %i/%i (rows %i-%i) already ported",
|
"Table %s: %i/%i (rows %i-%i) already ported",
|
||||||
table,
|
table,
|
||||||
@ -391,7 +423,9 @@ class Porter:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
def r(txn):
|
def r(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
|
||||||
forward_rows = []
|
forward_rows = []
|
||||||
backward_rows = []
|
backward_rows = []
|
||||||
if do_forward[0]:
|
if do_forward[0]:
|
||||||
@ -418,6 +452,7 @@ class Porter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if frows or brows:
|
if frows or brows:
|
||||||
|
assert headers is not None
|
||||||
if frows:
|
if frows:
|
||||||
forward_chunk = max(row[0] for row in frows) + 1
|
forward_chunk = max(row[0] for row in frows) + 1
|
||||||
if brows:
|
if brows:
|
||||||
@ -426,7 +461,8 @@ class Porter:
|
|||||||
rows = frows + brows
|
rows = frows + brows
|
||||||
rows = self._convert_rows(table, headers, rows)
|
rows = self._convert_rows(table, headers, rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn: LoggingTransaction) -> None:
|
||||||
|
assert headers is not None
|
||||||
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
|
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
|
||||||
|
|
||||||
self.postgres_store.db_pool.simple_update_one_txn(
|
self.postgres_store.db_pool.simple_update_one_txn(
|
||||||
@ -448,8 +484,12 @@ class Porter:
|
|||||||
return
|
return
|
||||||
|
|
||||||
async def handle_search_table(
|
async def handle_search_table(
|
||||||
self, postgres_size, table_size, forward_chunk, backward_chunk
|
self,
|
||||||
):
|
postgres_size: int,
|
||||||
|
table_size: int,
|
||||||
|
forward_chunk: int,
|
||||||
|
backward_chunk: int,
|
||||||
|
) -> None:
|
||||||
select = (
|
select = (
|
||||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||||
" FROM event_search as es"
|
" FROM event_search as es"
|
||||||
@ -460,7 +500,7 @@ class Porter:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
|
||||||
txn.execute(select, (forward_chunk, self.batch_size))
|
txn.execute(select, (forward_chunk, self.batch_size))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
headers = [column[0] for column in txn.description]
|
headers = [column[0] for column in txn.description]
|
||||||
@ -474,7 +514,7 @@ class Porter:
|
|||||||
|
|
||||||
# We have to treat event_search differently since it has a
|
# We have to treat event_search differently since it has a
|
||||||
# different structure in the two different databases.
|
# different structure in the two different databases.
|
||||||
def insert(txn):
|
def insert(txn: LoggingTransaction) -> None:
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT INTO event_search (event_id, room_id, key,"
|
"INSERT INTO event_search (event_id, room_id, key,"
|
||||||
" sender, vector, origin_server_ts, stream_ordering)"
|
" sender, vector, origin_server_ts, stream_ordering)"
|
||||||
@ -528,7 +568,7 @@ class Porter:
|
|||||||
self,
|
self,
|
||||||
db_config: DatabaseConnectionConfig,
|
db_config: DatabaseConnectionConfig,
|
||||||
allow_outdated_version: bool = False,
|
allow_outdated_version: bool = False,
|
||||||
):
|
) -> Store:
|
||||||
"""Builds and returns a database store using the provided configuration.
|
"""Builds and returns a database store using the provided configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -556,7 +596,7 @@ class Porter:
|
|||||||
|
|
||||||
return store
|
return store
|
||||||
|
|
||||||
async def run_background_updates_on_postgres(self):
|
async def run_background_updates_on_postgres(self) -> None:
|
||||||
# Manually apply all background updates on the PostgreSQL database.
|
# Manually apply all background updates on the PostgreSQL database.
|
||||||
postgres_ready = (
|
postgres_ready = (
|
||||||
await self.postgres_store.db_pool.updates.has_completed_background_updates()
|
await self.postgres_store.db_pool.updates.has_completed_background_updates()
|
||||||
@ -568,12 +608,12 @@ class Porter:
|
|||||||
self.progress.set_state("Running background updates on PostgreSQL")
|
self.progress.set_state("Running background updates on PostgreSQL")
|
||||||
|
|
||||||
while not postgres_ready:
|
while not postgres_ready:
|
||||||
await self.postgres_store.db_pool.updates.do_next_background_update(100)
|
await self.postgres_store.db_pool.updates.do_next_background_update(True)
|
||||||
postgres_ready = await (
|
postgres_ready = await (
|
||||||
self.postgres_store.db_pool.updates.has_completed_background_updates()
|
self.postgres_store.db_pool.updates.has_completed_background_updates()
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self) -> None:
|
||||||
"""Ports the SQLite database to a PostgreSQL database.
|
"""Ports the SQLite database to a PostgreSQL database.
|
||||||
|
|
||||||
When a fatal error is met, its message is assigned to the global "end_error"
|
When a fatal error is met, its message is assigned to the global "end_error"
|
||||||
@ -609,7 +649,7 @@ class Porter:
|
|||||||
|
|
||||||
self.progress.set_state("Creating port tables")
|
self.progress.set_state("Creating port tables")
|
||||||
|
|
||||||
def create_port_table(txn):
|
def create_port_table(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
|
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
|
||||||
" table_name varchar(100) NOT NULL UNIQUE,"
|
" table_name varchar(100) NOT NULL UNIQUE,"
|
||||||
@ -622,7 +662,7 @@ class Porter:
|
|||||||
# We want people to be able to rerun this script from an old port
|
# We want people to be able to rerun this script from an old port
|
||||||
# so that they can pick up any missing events that were not
|
# so that they can pick up any missing events that were not
|
||||||
# ported across.
|
# ported across.
|
||||||
def alter_table(txn):
|
def alter_table(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||||
" RENAME rowid TO forward_rowid"
|
" RENAME rowid TO forward_rowid"
|
||||||
@ -742,7 +782,9 @@ class Porter:
|
|||||||
finally:
|
finally:
|
||||||
reactor.stop()
|
reactor.stop()
|
||||||
|
|
||||||
def _convert_rows(self, table, headers, rows):
|
def _convert_rows(
|
||||||
|
self, table: str, headers: List[str], rows: List[Tuple]
|
||||||
|
) -> List[Tuple]:
|
||||||
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
|
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
|
||||||
|
|
||||||
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
|
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
|
||||||
@ -750,7 +792,7 @@ class Porter:
|
|||||||
class BadValueException(Exception):
|
class BadValueException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def conv(j, col):
|
def conv(j: int, col: object) -> object:
|
||||||
if j in bool_cols:
|
if j in bool_cols:
|
||||||
return bool(col)
|
return bool(col)
|
||||||
if isinstance(col, bytes):
|
if isinstance(col, bytes):
|
||||||
@ -776,7 +818,7 @@ class Porter:
|
|||||||
|
|
||||||
return outrows
|
return outrows
|
||||||
|
|
||||||
async def _setup_sent_transactions(self):
|
async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
|
||||||
# Only save things from the last day
|
# Only save things from the last day
|
||||||
yesterday = int(time.time() * 1000) - 86400000
|
yesterday = int(time.time() * 1000) - 86400000
|
||||||
|
|
||||||
@ -788,10 +830,10 @@ class Porter:
|
|||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
|
||||||
txn.execute(select)
|
txn.execute(select)
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
headers = [column[0] for column in txn.description]
|
headers: List[str] = [column[0] for column in txn.description]
|
||||||
|
|
||||||
ts_ind = headers.index("ts")
|
ts_ind = headers.index("ts")
|
||||||
|
|
||||||
@ -805,7 +847,7 @@ class Porter:
|
|||||||
if inserted_rows:
|
if inserted_rows:
|
||||||
max_inserted_rowid = max(r[0] for r in rows)
|
max_inserted_rowid = max(r[0] for r in rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn: LoggingTransaction) -> None:
|
||||||
self.postgres_store.insert_many_txn(
|
self.postgres_store.insert_many_txn(
|
||||||
txn, "sent_transactions", headers[1:], rows
|
txn, "sent_transactions", headers[1:], rows
|
||||||
)
|
)
|
||||||
@ -814,7 +856,7 @@ class Porter:
|
|||||||
else:
|
else:
|
||||||
max_inserted_rowid = 0
|
max_inserted_rowid = 0
|
||||||
|
|
||||||
def get_start_id(txn):
|
def get_start_id(txn: LoggingTransaction) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
|
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
|
||||||
" ORDER BY rowid ASC LIMIT 1",
|
" ORDER BY rowid ASC LIMIT 1",
|
||||||
@ -839,12 +881,13 @@ class Porter:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sent_table_size(txn):
|
def get_sent_table_size(txn: LoggingTransaction) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
|
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
|
||||||
)
|
)
|
||||||
(size,) = txn.fetchone()
|
result = txn.fetchone()
|
||||||
return int(size)
|
assert result is not None
|
||||||
|
return int(result[0])
|
||||||
|
|
||||||
remaining_count = await self.sqlite_store.execute(get_sent_table_size)
|
remaining_count = await self.sqlite_store.execute(get_sent_table_size)
|
||||||
|
|
||||||
@ -852,25 +895,35 @@ class Porter:
|
|||||||
|
|
||||||
return next_chunk, inserted_rows, total_count
|
return next_chunk, inserted_rows, total_count
|
||||||
|
|
||||||
async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
|
async def _get_remaining_count_to_port(
|
||||||
frows = await self.sqlite_store.execute_sql(
|
self, table: str, forward_chunk: int, backward_chunk: int
|
||||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
|
) -> int:
|
||||||
|
frows = cast(
|
||||||
|
List[Tuple[int]],
|
||||||
|
await self.sqlite_store.execute_sql(
|
||||||
|
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
brows = await self.sqlite_store.execute_sql(
|
brows = cast(
|
||||||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
|
List[Tuple[int]],
|
||||||
|
await self.sqlite_store.execute_sql(
|
||||||
|
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return frows[0][0] + brows[0][0]
|
return frows[0][0] + brows[0][0]
|
||||||
|
|
||||||
async def _get_already_ported_count(self, table):
|
async def _get_already_ported_count(self, table: str) -> int:
|
||||||
rows = await self.postgres_store.execute_sql(
|
rows = await self.postgres_store.execute_sql(
|
||||||
"SELECT count(*) FROM %s" % (table,)
|
"SELECT count(*) FROM %s" % (table,)
|
||||||
)
|
)
|
||||||
|
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
|
|
||||||
async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
|
async def _get_total_count_to_port(
|
||||||
|
self, table: str, forward_chunk: int, backward_chunk: int
|
||||||
|
) -> Tuple[int, int]:
|
||||||
remaining, done = await make_deferred_yieldable(
|
remaining, done = await make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[
|
[
|
||||||
@ -891,14 +944,17 @@ class Porter:
|
|||||||
return done, remaining + done
|
return done, remaining + done
|
||||||
|
|
||||||
async def _setup_state_group_id_seq(self) -> None:
|
async def _setup_state_group_id_seq(self) -> None:
|
||||||
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
curr_id: Optional[
|
||||||
|
int
|
||||||
|
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if not curr_id:
|
if not curr_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> None:
|
||||||
|
assert curr_id is not None
|
||||||
next_id = curr_id + 1
|
next_id = curr_id + 1
|
||||||
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
|
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
|
||||||
|
|
||||||
@ -909,7 +965,7 @@ class Porter:
|
|||||||
"setup_user_id_seq", find_max_generated_user_id_localpart
|
"setup_user_id_seq", find_max_generated_user_id_localpart
|
||||||
)
|
)
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> None:
|
||||||
next_id = curr_id + 1
|
next_id = curr_id + 1
|
||||||
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
|
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
|
||||||
|
|
||||||
@ -931,7 +987,7 @@ class Porter:
|
|||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _setup_events_stream_seqs_set_pos(txn):
|
def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
|
||||||
if curr_forward_id:
|
if curr_forward_id:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"ALTER SEQUENCE events_stream_seq RESTART WITH %s",
|
"ALTER SEQUENCE events_stream_seq RESTART WITH %s",
|
||||||
@ -955,17 +1011,20 @@ class Porter:
|
|||||||
"""Set a sequence to the correct value."""
|
"""Set a sequence to the correct value."""
|
||||||
current_stream_ids = []
|
current_stream_ids = []
|
||||||
for stream_id_table in stream_id_tables:
|
for stream_id_table in stream_id_tables:
|
||||||
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
max_stream_id = cast(
|
||||||
table=stream_id_table,
|
int,
|
||||||
keyvalues={},
|
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
retcol="COALESCE(MAX(stream_id), 1)",
|
table=stream_id_table,
|
||||||
allow_none=True,
|
keyvalues={},
|
||||||
|
retcol="COALESCE(MAX(stream_id), 1)",
|
||||||
|
allow_none=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
current_stream_ids.append(max_stream_id)
|
current_stream_ids.append(max_stream_id)
|
||||||
|
|
||||||
next_id = max(current_stream_ids) + 1
|
next_id = max(current_stream_ids) + 1
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> None:
|
||||||
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
|
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
|
||||||
txn.execute(sql + " %s", (next_id,))
|
txn.execute(sql + " %s", (next_id,))
|
||||||
|
|
||||||
@ -974,14 +1033,18 @@ class Porter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _setup_auth_chain_sequence(self) -> None:
|
async def _setup_auth_chain_sequence(self) -> None:
|
||||||
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
curr_chain_id: Optional[
|
||||||
|
int
|
||||||
|
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="event_auth_chains",
|
table="event_auth_chains",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcol="MAX(chain_id)",
|
retcol="MAX(chain_id)",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def r(txn):
|
def r(txn: LoggingTransaction) -> None:
|
||||||
|
# Presumably there is at least one row in event_auth_chains.
|
||||||
|
assert curr_chain_id is not None
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
||||||
(curr_chain_id + 1,),
|
(curr_chain_id + 1,),
|
||||||
@ -999,15 +1062,22 @@ class Porter:
|
|||||||
##############################################
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
class Progress(object):
|
class TableProgress(TypedDict):
|
||||||
|
start: int
|
||||||
|
num_done: int
|
||||||
|
total: int
|
||||||
|
perc: int
|
||||||
|
|
||||||
|
|
||||||
|
class Progress:
|
||||||
"""Used to report progress of the port"""
|
"""Used to report progress of the port"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.tables = {}
|
self.tables: Dict[str, TableProgress] = {}
|
||||||
|
|
||||||
self.start_time = int(time.time())
|
self.start_time = int(time.time())
|
||||||
|
|
||||||
def add_table(self, table, cur, size):
|
def add_table(self, table: str, cur: int, size: int) -> None:
|
||||||
self.tables[table] = {
|
self.tables[table] = {
|
||||||
"start": cur,
|
"start": cur,
|
||||||
"num_done": cur,
|
"num_done": cur,
|
||||||
@ -1015,19 +1085,22 @@ class Progress(object):
|
|||||||
"perc": int(cur * 100 / size),
|
"perc": int(cur * 100 / size),
|
||||||
}
|
}
|
||||||
|
|
||||||
def update(self, table, num_done):
|
def update(self, table: str, num_done: int) -> None:
|
||||||
data = self.tables[table]
|
data = self.tables[table]
|
||||||
data["num_done"] = num_done
|
data["num_done"] = num_done
|
||||||
data["perc"] = int(num_done * 100 / data["total"])
|
data["perc"] = int(num_done * 100 / data["total"])
|
||||||
|
|
||||||
def done(self):
|
def done(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_state(self, state: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CursesProgress(Progress):
|
class CursesProgress(Progress):
|
||||||
"""Reports progress to a curses window"""
|
"""Reports progress to a curses window"""
|
||||||
|
|
||||||
def __init__(self, stdscr):
|
def __init__(self, stdscr: "curses.window"):
|
||||||
self.stdscr = stdscr
|
self.stdscr = stdscr
|
||||||
|
|
||||||
curses.use_default_colors()
|
curses.use_default_colors()
|
||||||
@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
|
|||||||
|
|
||||||
super(CursesProgress, self).__init__()
|
super(CursesProgress, self).__init__()
|
||||||
|
|
||||||
def update(self, table, num_done):
|
def update(self, table: str, num_done: int) -> None:
|
||||||
super(CursesProgress, self).update(table, num_done)
|
super(CursesProgress, self).update(table, num_done)
|
||||||
|
|
||||||
self.total_processed = 0
|
self.total_processed = 0
|
||||||
@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
|
|||||||
|
|
||||||
self.render()
|
self.render()
|
||||||
|
|
||||||
def render(self, force=False):
|
def render(self, force: bool = False) -> None:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
if not force and now - self.last_update < 0.2:
|
if not force and now - self.last_update < 0.2:
|
||||||
@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
|
|||||||
self.stdscr.refresh()
|
self.stdscr.refresh()
|
||||||
self.last_update = time.time()
|
self.last_update = time.time()
|
||||||
|
|
||||||
def done(self):
|
def done(self) -> None:
|
||||||
self.finished = True
|
self.finished = True
|
||||||
self.render(True)
|
self.render(True)
|
||||||
self.stdscr.getch()
|
self.stdscr.getch()
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state: str) -> None:
|
||||||
self.stdscr.clear()
|
self.stdscr.clear()
|
||||||
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
|
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
|
||||||
self.stdscr.refresh()
|
self.stdscr.refresh()
|
||||||
@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
|
|||||||
class TerminalProgress(Progress):
|
class TerminalProgress(Progress):
|
||||||
"""Just prints progress to the terminal"""
|
"""Just prints progress to the terminal"""
|
||||||
|
|
||||||
def update(self, table, num_done):
|
def update(self, table: str, num_done: int) -> None:
|
||||||
super(TerminalProgress, self).update(table, num_done)
|
super(TerminalProgress, self).update(table, num_done)
|
||||||
|
|
||||||
data = self.tables[table]
|
data = self.tables[table]
|
||||||
@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
|
|||||||
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
|
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state: str) -> None:
|
||||||
print(state + "...")
|
print(state + "...")
|
||||||
|
|
||||||
|
|
||||||
@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
|
|||||||
##############################################
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="A script to port an existing synapse SQLite database to"
|
description="A script to port an existing synapse SQLite database to"
|
||||||
" a new PostgreSQL database."
|
" a new PostgreSQL database."
|
||||||
@ -1225,7 +1298,7 @@ def main():
|
|||||||
config = HomeServerConfig()
|
config = HomeServerConfig()
|
||||||
config.parse_config_dict(hs_config, "", "")
|
config.parse_config_dict(hs_config, "", "")
|
||||||
|
|
||||||
def start(stdscr=None):
|
def start(stdscr: Optional["curses.window"] = None) -> None:
|
||||||
progress: Progress
|
progress: Progress
|
||||||
if stdscr:
|
if stdscr:
|
||||||
progress = CursesProgress(stdscr)
|
progress = CursesProgress(stdscr)
|
||||||
@ -1240,7 +1313,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def run():
|
def run() -> Generator["defer.Deferred[Any]", Any, None]:
|
||||||
with LoggingContext("synapse_port_db_run"):
|
with LoggingContext("synapse_port_db_run"):
|
||||||
yield defer.ensureDeferred(porter.run())
|
yield defer.ensureDeferred(porter.run())
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ import signal
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, NoReturn, Optional, TextIO
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ one of the following:
|
|||||||
--------------------------------------------------------------------------------"""
|
--------------------------------------------------------------------------------"""
|
||||||
|
|
||||||
|
|
||||||
def pid_running(pid):
|
def pid_running(pid: int) -> bool:
|
||||||
try:
|
try:
|
||||||
os.kill(pid, 0)
|
os.kill(pid, 0)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
@ -68,7 +68,7 @@ def pid_running(pid):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
|
||||||
# Lets check if we're writing to a TTY before colouring
|
# Lets check if we're writing to a TTY before colouring
|
||||||
should_colour = False
|
should_colour = False
|
||||||
try:
|
try:
|
||||||
@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
|
|||||||
stream.write(colour + message + NORMAL + "\n")
|
stream.write(colour + message + NORMAL + "\n")
|
||||||
|
|
||||||
|
|
||||||
def abort(message, colour=RED, stream=sys.stderr):
|
def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
|
||||||
write(message, colour, stream)
|
write(message, colour, stream)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ Worker = collections.namedtuple(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
@ -38,25 +38,25 @@ logger = logging.getLogger("update_database")
|
|||||||
class MockHomeserver(HomeServer):
|
class MockHomeserver(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config: HomeServerConfig):
|
||||||
super(MockHomeserver, self).__init__(
|
super(MockHomeserver, self).__init__(
|
||||||
config.server.server_name, reactor=reactor, config=config, **kwargs
|
hostname=config.server.server_name,
|
||||||
)
|
config=config,
|
||||||
|
reactor=reactor,
|
||||||
self.version_string = "Synapse/" + get_distribution_version_string(
|
version_string="Synapse/"
|
||||||
"matrix-synapse"
|
+ get_distribution_version_string("matrix-synapse"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_background_updates(hs):
|
def run_background_updates(hs: HomeServer) -> None:
|
||||||
store = hs.get_datastores().main
|
store = hs.get_datastores().main
|
||||||
|
|
||||||
async def run_background_updates():
|
async def run_background_updates() -> None:
|
||||||
await store.db_pool.updates.run_background_updates(sleep=False)
|
await store.db_pool.updates.run_background_updates(sleep=False)
|
||||||
# Stop the reactor to exit the script once every background update is run.
|
# Stop the reactor to exit the script once every background update is run.
|
||||||
reactor.stop()
|
reactor.stop()
|
||||||
|
|
||||||
def run():
|
def run() -> None:
|
||||||
# Apply all background updates on the database.
|
# Apply all background updates on the database.
|
||||||
defer.ensureDeferred(
|
defer.ensureDeferred(
|
||||||
run_as_background_process("background_updates", run_background_updates)
|
run_as_background_process("background_updates", run_background_updates)
|
||||||
@ -67,7 +67,7 @@ def run_background_updates(hs):
|
|||||||
reactor.run()
|
reactor.run()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description=(
|
description=(
|
||||||
"Updates a synapse database to the latest schema and optionally runs background updates"
|
"Updates a synapse database to the latest schema and optionally runs background updates"
|
||||||
|
@ -45,6 +45,7 @@ class Cursor(Protocol):
|
|||||||
Sequence[
|
Sequence[
|
||||||
# Note that this is an approximate typing based on sqlite3 and other
|
# Note that this is an approximate typing based on sqlite3 and other
|
||||||
# drivers, and may not be entirely accurate.
|
# drivers, and may not be entirely accurate.
|
||||||
|
# FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
|
||||||
Tuple[
|
Tuple[
|
||||||
str,
|
str,
|
||||||
Optional[Any],
|
Optional[Any],
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from synapse._scripts.register_new_matrix_user import request_registration
|
from synapse._scripts.register_new_matrix_user import request_registration
|
||||||
|
|
||||||
@ -52,16 +52,16 @@ class RegisterTestCase(TestCase):
|
|||||||
out = []
|
out = []
|
||||||
err_code = []
|
err_code = []
|
||||||
|
|
||||||
request_registration(
|
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
|
||||||
"user",
|
request_registration(
|
||||||
"pass",
|
"user",
|
||||||
"matrix.org",
|
"pass",
|
||||||
"shared",
|
"matrix.org",
|
||||||
admin=False,
|
"shared",
|
||||||
requests=requests,
|
admin=False,
|
||||||
_print=out.append,
|
_print=out.append,
|
||||||
exit=err_code.append,
|
exit=err_code.append,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should get the success message making sure everything is OK.
|
# We should get the success message making sure everything is OK.
|
||||||
self.assertIn("Success!", out)
|
self.assertIn("Success!", out)
|
||||||
@ -88,16 +88,16 @@ class RegisterTestCase(TestCase):
|
|||||||
out = []
|
out = []
|
||||||
err_code = []
|
err_code = []
|
||||||
|
|
||||||
request_registration(
|
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
|
||||||
"user",
|
request_registration(
|
||||||
"pass",
|
"user",
|
||||||
"matrix.org",
|
"pass",
|
||||||
"shared",
|
"matrix.org",
|
||||||
admin=False,
|
"shared",
|
||||||
requests=requests,
|
admin=False,
|
||||||
_print=out.append,
|
_print=out.append,
|
||||||
exit=err_code.append,
|
exit=err_code.append,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit was called
|
# Exit was called
|
||||||
self.assertEqual(err_code, [1])
|
self.assertEqual(err_code, [1])
|
||||||
@ -140,16 +140,16 @@ class RegisterTestCase(TestCase):
|
|||||||
out = []
|
out = []
|
||||||
err_code = []
|
err_code = []
|
||||||
|
|
||||||
request_registration(
|
with patch("synapse._scripts.register_new_matrix_user.requests", requests):
|
||||||
"user",
|
request_registration(
|
||||||
"pass",
|
"user",
|
||||||
"matrix.org",
|
"pass",
|
||||||
"shared",
|
"matrix.org",
|
||||||
admin=False,
|
"shared",
|
||||||
requests=requests,
|
admin=False,
|
||||||
_print=out.append,
|
_print=out.append,
|
||||||
exit=err_code.append,
|
exit=err_code.append,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exit was called
|
# Exit was called
|
||||||
self.assertEqual(err_code, [1])
|
self.assertEqual(err_code, [1])
|
||||||
|
Loading…
Reference in New Issue
Block a user