mirror of https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-01 11:24:55 -05:00

Merge remote-tracking branch 'upstream/release-v1.57'

This commit is contained in: b2fa6ec9f6
248 changed files with 14616 additions and 8934 deletions
@@ -68,7 +68,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.56.0"
+__version__ = "1.57.1"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

@@ -15,19 +15,19 @@
 import argparse
 import sys
 import time
-from typing import Optional
+from typing import NoReturn, Optional

-import nacl.signing
 from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
+from signedjson.types import VerifyKey


-def exit(status: int = 0, message: Optional[str] = None):
+def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
     if message:
         print(message, file=sys.stderr)
     sys.exit(status)


-def format_plain(public_key: nacl.signing.VerifyKey):
+def format_plain(public_key: VerifyKey) -> None:
     print(
         "%s:%s %s"
         % (

@@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
     )


-def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
+def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
     print(
         ' "%s:%s": { key: "%s", expired_ts: %i }'
         % (

@@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
     )


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()

     parser.add_argument(

@@ -94,7 +94,6 @@ def main():
                 message="Error reading key from file %s: %s %s"
                 % (file.name, type(e), e),
             )
-            res = []
         for key in res:
             formatter(get_verify_key(key))

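The hunks above annotate a helper that always terminates the process with `typing.NoReturn`, which is also why the defensive `res = []` after the `exit(...)` call could be deleted: the type checker now knows that path never falls through. A minimal standalone sketch of the idea (the `bail` and `parse_port` names are illustrative, not from the diff):

import sys
from typing import NoReturn, Optional


def bail(status: int = 0, message: Optional[str] = None) -> NoReturn:
    # NoReturn tells mypy that control never comes back, so any code
    # placed after a call to bail() is treated as unreachable.
    if message:
        print(message, file=sys.stderr)
    sys.exit(status)


def parse_port(raw: str) -> int:
    if not raw.isdigit():
        bail(1, "port must be numeric")
    # Without NoReturn, mypy would think this line is reachable with an
    # unvalidated value; with it, the narrowing above is understood.
    return int(raw)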
@@ -7,7 +7,7 @@ import sys
 from synapse.config.homeserver import HomeServerConfig


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config-dir",

@@ -20,7 +20,7 @@ import sys
 from synapse.config.logger import DEFAULT_LOG_CONFIG


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()

     parser.add_argument(

@@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
 from synapse.util.stringutils import random_string


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()

     parser.add_argument(

@@ -9,7 +9,7 @@ import bcrypt
 import yaml


-def prompt_for_pass():
+def prompt_for_pass() -> str:
     password = getpass.getpass("Password: ")

     if not password:

@@ -23,7 +23,7 @@ def prompt_for_pass():
     return password


-def main():
+def main() -> None:
     bcrypt_rounds = 12
     password_pepper = ""

@@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 logger = logging.getLogger()


-def main(src_repo, dest_repo):
+def main(src_repo: str, dest_repo: str) -> None:
     src_paths = MediaFilePaths(src_repo)
     dest_paths = MediaFilePaths(dest_repo)
     for line in sys.stdin:

@@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
         move_media(parts[0], parts[1], src_paths, dest_paths)


-def move_media(origin_server, file_id, src_paths, dest_paths):
+def move_media(
+    origin_server: str,
+    file_id: str,
+    src_paths: MediaFilePaths,
+    dest_paths: MediaFilePaths,
+) -> None:
     """Move the given file, and any thumbnails, to the dest repo

     Args:
-        origin_server (str):
-        file_id (str):
-        src_paths (MediaFilePaths):
-        dest_paths (MediaFilePaths):
+        origin_server:
+        file_id:
+        src_paths:
+        dest_paths:
     """
     logger.info("%s/%s", origin_server, file_id)

@@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
     )


-def mkdir_and_move(original_file, dest_file):
+def mkdir_and_move(original_file: str, dest_file: str) -> None:
     dirname = os.path.dirname(dest_file)
     if not os.path.exists(dirname):
         logger.debug("mkdir %s", dirname)

@@ -109,10 +114,9 @@ if __name__ == "__main__":
     parser.add_argument("dest_repo", help="Path to source content repo")
     args = parser.parse_args()

-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+    )

     main(args.src_repo, args.dest_repo)

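The last hunk above collapses a `logging_config` dict plus `logging.basicConfig(**logging_config)` into a single call; the same refactor recurs below in `synapse_port_db` and `update_synapse_database`. It works because `basicConfig` treats an explicit `filename=None` the same as omitting the argument (log to stderr), so even a conditional file target can be passed inline. A minimal sketch, assuming an `args` object carrying `v` and `curses` flags as in these scripts:

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("-v", action="store_true")
parser.add_argument("--curses", action="store_true")
args = parser.parse_args([])

# filename=None means "log to stderr", which is also the default, so the
# conditional collapses into one basicConfig call instead of a mutated dict.
logging.basicConfig(
    level=logging.DEBUG if args.v else logging.INFO,
    format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
    filename="port-synapse.log" if args.curses else None,
)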
@@ -22,7 +22,7 @@ import logging
 import sys
 from typing import Callable, Optional

-import requests as _requests
+import requests
 import yaml

@@ -33,7 +33,6 @@ def request_registration(
     shared_secret: str,
     admin: bool = False,
     user_type: Optional[str] = None,
-    requests=_requests,
     _print: Callable[[str], None] = print,
     exit: Callable[[int], None] = sys.exit,
 ) -> None:

@@ -138,9 +138,7 @@ def main() -> None:
     config_args = parser.parse_args(sys.argv[1:])
     config_files = find_config_files(search_paths=config_args.config_path)
     config_dict = read_config_files(config_files)
-    config.parse_config_dict(
-        config_dict,
-    )
+    config.parse_config_dict(config_dict, "", "")

     since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
     exclude_users_with_email = config_args.exclude_emails

@@ -21,12 +21,29 @@ import logging
 import sys
 import time
 import traceback
-from typing import Dict, Iterable, Optional, Set
+from types import TracebackType
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    NoReturn,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)

 import yaml
 from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import TypedDict

-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_

 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig

@@ -35,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import PushRuleStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore

@@ -66,8 +83,12 @@ from synapse.storage.databases.main.user_directory import (
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
+from synapse.types import ISynapseReactor
 from synapse.util import Clock

+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
 logger = logging.getLogger("synapse_port_db")

@@ -97,6 +118,7 @@ BOOLEAN_COLUMNS = {
     "users": ["shadow_banned"],
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
+    "device_lists_changes_in_room": ["converted_to_destinations"],
 }

@@ -158,12 +180,16 @@ IGNORED_TABLES = {

 # Error returned by the run function. Used at the top-level part of the script to
 # handle errors and return codes.
-end_error = None  # type: Optional[str]
+end_error: Optional[str] = None
 # The exec_info for the error, if any. If error is defined but not exec_info the script
 # will show only the error message without the stacktrace, if exec_info is defined but
 # not the error then the script will show nothing outside of what's printed in the run
 # function. If both are defined, the script will print both the error and the stacktrace.
-end_error_exec_info = None
+end_error_exec_info: Optional[
+    Tuple[Type[BaseException], BaseException, TracebackType]
+] = None
+
+R = TypeVar("R")


 class Store(

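The global above is typed as an `Optional` triple because the full `(type, value, traceback)` tuple only exists while an exception is being handled; `sys.exc_info()` itself is typed as a tuple of Optionals, hence the `type: ignore[assignment]` seen later in this diff. A small self-contained sketch of the same pattern (the `last_error` name is illustrative):

import sys
from types import TracebackType
from typing import Optional, Tuple, Type

# Outside an except block sys.exc_info() is (None, None, None), so the
# narrowed global has to be Optional.
last_error: Optional[Tuple[Type[BaseException], BaseException, TracebackType]] = None

try:
    raise ValueError("boom")
except ValueError:
    # Inside the handler the triple is fully populated, but the declared
    # return type of exc_info() is looser, so the assignment needs an ignore.
    last_error = sys.exc_info()  # type: ignore[assignment]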
@@ -187,17 +213,19 @@ class Store(
     PresenceBackgroundUpdateStore,
     GroupServerWorkerStore,
 ):
-    def execute(self, f, *args, **kwargs):
+    def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

-    def execute_sql(self, sql, *args):
-        def r(txn):
+    def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
+        def r(txn: LoggingTransaction) -> List[Tuple]:
             txn.execute(sql, args)
             return txn.fetchall()

         return self.db_pool.runInteraction("execute_sql", r)

-    def insert_many_txn(self, txn, table, headers, rows):
+    def insert_many_txn(
+        self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
+    ) -> None:
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in headers),

@@ -210,14 +238,15 @@ class Store(
             logger.exception("Failed to insert: %s", table)
             raise

-    def set_room_is_public(self, room_id, is_public):
+    # Note: the parent method is an `async def`.
+    def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
         raise Exception(
             "Attempt to set room_is_public during port_db: database not empty?"
         )


 class MockHomeserver:
-    def __init__(self, config):
+    def __init__(self, config: HomeServerConfig):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server.server_name

@@ -225,21 +254,30 @@ class MockHomeserver:
         "matrix-synapse"
     )

-    def get_clock(self):
+    def get_clock(self) -> Clock:
         return self.clock

-    def get_reactor(self):
+    def get_reactor(self) -> ISynapseReactor:
         return reactor

-    def get_instance_name(self):
+    def get_instance_name(self) -> str:
         return "master"


-class Porter(object):
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
+class Porter:
+    def __init__(
+        self,
+        sqlite_config: Dict[str, Any],
+        progress: "Progress",
+        batch_size: int,
+        hs_config: HomeServerConfig,
+    ):
+        self.sqlite_config = sqlite_config
+        self.progress = progress
+        self.batch_size = batch_size
+        self.hs_config = hs_config

-    async def setup_table(self, table):
+    async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
             row = await self.postgres_store.db_pool.simple_select_one(

@@ -281,7 +319,7 @@ class Porter(object):
             )
         else:

-            def delete_all(txn):
+            def delete_all(txn: LoggingTransaction) -> None:
                 txn.execute(
                     "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
                 )

@@ -306,7 +344,7 @@ class Porter(object):
     async def get_table_constraints(self) -> Dict[str, Set[str]]:
         """Returns a map of tables that have foreign key constraints to tables they depend on."""

-        def _get_constraints(txn):
+        def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
             # We can pull the information about foreign key constraints out from
             # the postgres schema tables.
             sql = """

@@ -322,7 +360,7 @@ class Porter(object):
             """
             txn.execute(sql)

-            results = {}
+            results: Dict[str, Set[str]] = {}
             for table, foreign_table in txn:
                 results.setdefault(table, set()).add(foreign_table)
             return results

@@ -332,8 +370,13 @@ class Porter(object):
         )

     async def handle_table(
-        self, table, postgres_size, table_size, forward_chunk, backward_chunk
-    ):
+        self,
+        table: str,
+        postgres_size: int,
+        table_size: int,
+        forward_chunk: int,
+        backward_chunk: int,
+    ) -> None:
         logger.info(
             "Table %s: %i/%i (rows %i-%i) already ported",
             table,

@@ -380,7 +423,9 @@ class Porter(object):

         while True:

-            def r(txn):
+            def r(
+                txn: LoggingTransaction,
+            ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
                 forward_rows = []
                 backward_rows = []
                 if do_forward[0]:

@@ -407,6 +452,7 @@ class Porter(object):
             )

             if frows or brows:
+                assert headers is not None
                 if frows:
                     forward_chunk = max(row[0] for row in frows) + 1
                 if brows:

@@ -415,7 +461,8 @@ class Porter(object):
                 rows = frows + brows
                 rows = self._convert_rows(table, headers, rows)

-                def insert(txn):
+                def insert(txn: LoggingTransaction) -> None:
+                    assert headers is not None
                     self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)

                     self.postgres_store.db_pool.simple_update_one_txn(

@@ -437,8 +484,12 @@ class Porter(object):
                 return

     async def handle_search_table(
-        self, postgres_size, table_size, forward_chunk, backward_chunk
-    ):
+        self,
+        postgres_size: int,
+        table_size: int,
+        forward_chunk: int,
+        backward_chunk: int,
+    ) -> None:
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"

@@ -449,7 +500,7 @@ class Porter(object):

         while True:

-            def r(txn):
+            def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
                 txn.execute(select, (forward_chunk, self.batch_size))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]

@@ -463,7 +514,7 @@ class Porter(object):

             # We have to treat event_search differently since it has a
             # different structure in the two different databases.
-            def insert(txn):
+            def insert(txn: LoggingTransaction) -> None:
                 sql = (
                     "INSERT INTO event_search (event_id, room_id, key,"
                     " sender, vector, origin_server_ts, stream_ordering)"

@@ -517,7 +568,7 @@ class Porter(object):
         self,
         db_config: DatabaseConnectionConfig,
         allow_outdated_version: bool = False,
-    ):
+    ) -> Store:
         """Builds and returns a database store using the provided configuration.

         Args:

@@ -539,12 +590,13 @@ class Porter(object):
                 db_conn, allow_outdated_version=allow_outdated_version
             )
             prepare_database(db_conn, engine, config=self.hs_config)
-            store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
+            # Type safety: ignore that we're using Mock homeservers here.
+            store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)  # type: ignore[arg-type]
             db_conn.commit()

         return store

-    async def run_background_updates_on_postgres(self):
+    async def run_background_updates_on_postgres(self) -> None:
         # Manually apply all background updates on the PostgreSQL database.
         postgres_ready = (
             await self.postgres_store.db_pool.updates.has_completed_background_updates()

@@ -556,12 +608,12 @@ class Porter(object):
             self.progress.set_state("Running background updates on PostgreSQL")

         while not postgres_ready:
-            await self.postgres_store.db_pool.updates.do_next_background_update(100)
+            await self.postgres_store.db_pool.updates.do_next_background_update(True)
             postgres_ready = await (
                 self.postgres_store.db_pool.updates.has_completed_background_updates()
             )

-    async def run(self):
+    async def run(self) -> None:
         """Ports the SQLite database to a PostgreSQL database.

         When a fatal error is met, its message is assigned to the global "end_error"

@@ -597,7 +649,7 @@ class Porter(object):

         self.progress.set_state("Creating port tables")

-        def create_port_table(txn):
+        def create_port_table(txn: LoggingTransaction) -> None:
             txn.execute(
                 "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
                 " table_name varchar(100) NOT NULL UNIQUE,"

@@ -610,7 +662,7 @@ class Porter(object):
         # We want people to be able to rerun this script from an old port
         # so that they can pick up any missing events that were not
         # ported across.
-        def alter_table(txn):
+        def alter_table(txn: LoggingTransaction) -> None:
             txn.execute(
                 "ALTER TABLE IF EXISTS port_from_sqlite3"
                 " RENAME rowid TO forward_rowid"

@@ -723,12 +775,16 @@ class Porter(object):
         except Exception as e:
             global end_error_exec_info
             end_error = str(e)
-            end_error_exec_info = sys.exc_info()
+            # Type safety: we're in an exception handler, so the exc_info() tuple
+            # will not be (None, None, None).
+            end_error_exec_info = sys.exc_info()  # type: ignore[assignment]
             logger.exception("")
         finally:
             reactor.stop()

-    def _convert_rows(self, table, headers, rows):
+    def _convert_rows(
+        self, table: str, headers: List[str], rows: List[Tuple]
+    ) -> List[Tuple]:
         bool_col_names = BOOLEAN_COLUMNS.get(table, [])

         bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]

@@ -736,7 +792,7 @@ class Porter(object):
         class BadValueException(Exception):
             pass

-        def conv(j, col):
+        def conv(j: int, col: object) -> object:
             if j in bool_cols:
                 return bool(col)
             if isinstance(col, bytes):

@@ -762,7 +818,7 @@ class Porter(object):

         return outrows

-    async def _setup_sent_transactions(self):
+    async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
         # Only save things from the last day
         yesterday = int(time.time() * 1000) - 86400000

@@ -774,10 +830,10 @@ class Porter(object):
             ")"
         )

-        def r(txn):
+        def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
             txn.execute(select)
             rows = txn.fetchall()
-            headers = [column[0] for column in txn.description]
+            headers: List[str] = [column[0] for column in txn.description]

             ts_ind = headers.index("ts")

@@ -791,7 +847,7 @@ class Porter(object):
         if inserted_rows:
             max_inserted_rowid = max(r[0] for r in rows)

-            def insert(txn):
+            def insert(txn: LoggingTransaction) -> None:
                 self.postgres_store.insert_many_txn(
                     txn, "sent_transactions", headers[1:], rows
                 )

@@ -800,7 +856,7 @@ class Porter(object):
         else:
             max_inserted_rowid = 0

-        def get_start_id(txn):
+        def get_start_id(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                 " ORDER BY rowid ASC LIMIT 1",

@@ -825,12 +881,13 @@ class Porter(object):
             },
         )

-        def get_sent_table_size(txn):
+        def get_sent_table_size(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            (size,) = txn.fetchone()
-            return int(size)
+            result = txn.fetchone()
+            assert result is not None
+            return int(result[0])

         remaining_count = await self.sqlite_store.execute(get_sent_table_size)

@@ -838,25 +895,35 @@ class Porter(object):

         return next_chunk, inserted_rows, total_count

-    async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+    async def _get_remaining_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> int:
+        frows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+            ),
         )

-        brows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+        brows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+            ),
         )

         return frows[0][0] + brows[0][0]

-    async def _get_already_ported_count(self, table):
+    async def _get_already_ported_count(self, table: str) -> int:
         rows = await self.postgres_store.execute_sql(
             "SELECT count(*) FROM %s" % (table,)
         )

         return rows[0][0]

-    async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    async def _get_total_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> Tuple[int, int]:
         remaining, done = await make_deferred_yieldable(
             defer.gatherResults(
                 [

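The hunk above wraps `execute_sql` results in `typing.cast` because that method's declared return type is only the loose `List[Tuple]`, with no element types. `cast` is a zero-cost, checker-only assertion; a minimal standalone sketch (the `fetch_rows` stand-in is illustrative):

from typing import List, Tuple, cast


def fetch_rows() -> List[Tuple]:
    # Stand-in for Store.execute_sql: the static type says only
    # "a list of tuples", saying nothing about the elements.
    return [(42,)]


# cast() does nothing at runtime; it just tells the type checker that this
# particular COUNT(*) query yields one-element integer tuples.
counts = cast(List[Tuple[int]], fetch_rows())
total: int = counts[0][0]
assert total == 42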
@@ -877,14 +944,17 @@ class Porter(object):
         return done, remaining + done

     async def _setup_state_group_id_seq(self) -> None:
-        curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
         )

         if not curr_id:
             return

-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            assert curr_id is not None
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))

@@ -895,7 +965,7 @@ class Porter(object):
             "setup_user_id_seq", find_max_generated_user_id_localpart
         )

-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))

@@ -917,7 +987,7 @@ class Porter(object):
             allow_none=True,
         )

-        def _setup_events_stream_seqs_set_pos(txn):
+        def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
             if curr_forward_id:
                 txn.execute(
                     "ALTER SEQUENCE events_stream_seq RESTART WITH %s",

@@ -941,17 +1011,20 @@ class Porter(object):
         """Set a sequence to the correct value."""
         current_stream_ids = []
         for stream_id_table in stream_id_tables:
-            max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
-                table=stream_id_table,
-                keyvalues={},
-                retcol="COALESCE(MAX(stream_id), 1)",
-                allow_none=True,
+            max_stream_id = cast(
+                int,
+                await self.sqlite_store.db_pool.simple_select_one_onecol(
+                    table=stream_id_table,
+                    keyvalues={},
+                    retcol="COALESCE(MAX(stream_id), 1)",
+                    allow_none=True,
+                ),
             )
             current_stream_ids.append(max_stream_id)

         next_id = max(current_stream_ids) + 1

-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
             txn.execute(sql + " %s", (next_id,))

@@ -960,14 +1033,18 @@ class Porter(object):
         )

     async def _setup_auth_chain_sequence(self) -> None:
-        curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_chain_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="event_auth_chains",
             keyvalues={},
             retcol="MAX(chain_id)",
             allow_none=True,
         )

-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            # Presumably there is at least one row in event_auth_chains.
+            assert curr_chain_id is not None
             txn.execute(
                 "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
                 (curr_chain_id + 1,),

@@ -985,15 +1062,22 @@ class Porter(object):
 ##############################################


-class Progress(object):
+class TableProgress(TypedDict):
+    start: int
+    num_done: int
+    total: int
+    perc: int
+
+
+class Progress:
     """Used to report progress of the port"""

-    def __init__(self):
-        self.tables = {}
+    def __init__(self) -> None:
+        self.tables: Dict[str, TableProgress] = {}

         self.start_time = int(time.time())

-    def add_table(self, table, cur, size):
+    def add_table(self, table: str, cur: int, size: int) -> None:
         self.tables[table] = {
             "start": cur,
             "num_done": cur,

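The `TableProgress` TypedDict above replaces a plain untyped dict, so the checker can verify key names and value types at every access. A minimal sketch of what that buys (the `progress` variable is illustrative):

from typing_extensions import TypedDict


class TableProgress(TypedDict):
    start: int
    num_done: int
    total: int
    perc: int


progress: TableProgress = {"start": 0, "num_done": 50, "total": 100, "perc": 0}

# mypy now checks both the key name and the value type on each access:
progress["perc"] = int(progress["num_done"] * 100 / progress["total"])
# progress["done"] = 1  # would be flagged: "done" is not a declared key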
@@ -1001,19 +1085,22 @@ class Progress(object):
             "perc": int(cur * 100 / size),
         }

-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         data = self.tables[table]
         data["num_done"] = num_done
         data["perc"] = int(num_done * 100 / data["total"])

-    def done(self):
+    def done(self) -> None:
         pass

+    def set_state(self, state: str) -> None:
+        pass
+

 class CursesProgress(Progress):
     """Reports progress to a curses window"""

-    def __init__(self, stdscr):
+    def __init__(self, stdscr: "curses.window"):
         self.stdscr = stdscr

         curses.use_default_colors()

@@ -1022,7 +1109,7 @@ class CursesProgress(Progress):
         curses.init_pair(1, curses.COLOR_RED, -1)
         curses.init_pair(2, curses.COLOR_GREEN, -1)

-        self.last_update = 0
+        self.last_update = 0.0

         self.finished = False

@@ -1031,7 +1118,7 @@ class CursesProgress(Progress):

         super(CursesProgress, self).__init__()

-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(CursesProgress, self).update(table, num_done)

         self.total_processed = 0

@@ -1042,7 +1129,7 @@ class CursesProgress(Progress):

         self.render()

-    def render(self, force=False):
+    def render(self, force: bool = False) -> None:
         now = time.time()

         if not force and now - self.last_update < 0.2:

@@ -1081,8 +1168,7 @@ class CursesProgress(Progress):
         left_margin = 5
         middle_space = 1

-        items = self.tables.items()
-        items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
+        items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))

         for i, (table, data) in enumerate(items):
             if i + 2 >= rows:

@@ -1115,12 +1201,12 @@ class CursesProgress(Progress):
         self.stdscr.refresh()
         self.last_update = time.time()

-    def done(self):
+    def done(self) -> None:
         self.finished = True
         self.render(True)
         self.stdscr.getch()

-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         self.stdscr.clear()
         self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
         self.stdscr.refresh()

@@ -1129,7 +1215,7 @@ class CursesProgress(Progress):
 class TerminalProgress(Progress):
     """Just prints progress to the terminal"""

-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(TerminalProgress, self).update(table, num_done)

         data = self.tables[table]

@@ -1138,7 +1224,7 @@ class TerminalProgress(Progress):
             "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
         )

-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         print(state + "...")

@@ -1146,7 +1232,7 @@ class TerminalProgress(Progress):
 ##############################################


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."

@@ -1178,15 +1264,11 @@ def main():

     args = parser.parse_args()

-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-
-    if args.curses:
-        logging_config["filename"] = "port-synapse.log"
-
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+        filename="port-synapse.log" if args.curses else None,
+    )

     sqlite_config = {
         "name": "sqlite3",

@@ -1216,7 +1298,8 @@ def main():
     config = HomeServerConfig()
     config.parse_config_dict(hs_config, "", "")

-    def start(stdscr=None):
+    def start(stdscr: Optional["curses.window"] = None) -> None:
+        progress: Progress
         if stdscr:
             progress = CursesProgress(stdscr)
         else:

@@ -1230,7 +1313,7 @@ def main():
     )

     @defer.inlineCallbacks
-    def run():
+    def run() -> Generator["defer.Deferred[Any]", Any, None]:
         with LoggingContext("synapse_port_db_run"):
             yield defer.ensureDeferred(porter.run())

@@ -24,7 +24,7 @@ import signal
 import subprocess
 import sys
 import time
-from typing import Iterable, Optional
+from typing import Iterable, NoReturn, Optional, TextIO

 import yaml

@@ -45,7 +45,7 @@ one of the following:
 --------------------------------------------------------------------------------"""


-def pid_running(pid):
+def pid_running(pid: int) -> bool:
     try:
         os.kill(pid, 0)
     except OSError as err:

@@ -68,7 +68,7 @@ def pid_running(pid):
     return True


-def write(message, colour=NORMAL, stream=sys.stdout):
+def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
     # Lets check if we're writing to a TTY before colouring
     should_colour = False
     try:

@@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
     stream.write(colour + message + NORMAL + "\n")


-def abort(message, colour=RED, stream=sys.stderr):
+def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
     write(message, colour, stream)
     sys.exit(1)

@@ -166,7 +166,7 @@ Worker = collections.namedtuple(
 )


-def main():
+def main() -> None:

     parser = argparse.ArgumentParser()

@@ -16,42 +16,47 @@
 import argparse
 import logging
 import sys
+from typing import cast

 import yaml
 from matrix_common.versionstring import get_distribution_version_string

-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_

 from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.types import ISynapseReactor

+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
 logger = logging.getLogger("update_database")


 class MockHomeserver(HomeServer):
-    DATASTORE_CLASS = DataStore
+    DATASTORE_CLASS = DataStore  # type: ignore [assignment]

-    def __init__(self, config, **kwargs):
+    def __init__(self, config: HomeServerConfig):
         super(MockHomeserver, self).__init__(
-            config.server.server_name, reactor=reactor, config=config, **kwargs
-        )
-
-        self.version_string = "Synapse/" + get_distribution_version_string(
-            "matrix-synapse"
+            hostname=config.server.server_name,
+            config=config,
+            reactor=reactor,
+            version_string="Synapse/"
+            + get_distribution_version_string("matrix-synapse"),
         )


-def run_background_updates(hs):
+def run_background_updates(hs: HomeServer) -> None:
     store = hs.get_datastores().main

-    async def run_background_updates():
+    async def run_background_updates() -> None:
         await store.db_pool.updates.run_background_updates(sleep=False)
         # Stop the reactor to exit the script once every background update is run.
         reactor.stop()

-    def run():
+    def run() -> None:
         # Apply all background updates on the database.
         defer.ensureDeferred(
             run_as_background_process("background_updates", run_background_updates)

@@ -62,7 +67,7 @@ def run_background_updates(hs):
     reactor.run()


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description=(
             "Updates a synapse database to the latest schema and optionally runs background updates"

@@ -85,12 +90,10 @@ def main():

     args = parser.parse_args()

-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+    )

     # Load, process and sanity-check the config.
     hs_config = yaml.safe_load(args.database_config)

@@ -130,7 +130,7 @@ def start_reactor(
     appname: str,
     soft_file_limit: int,
     gc_thresholds: Optional[Tuple[int, int, int]],
-    pid_file: str,
+    pid_file: Optional[str],
     daemonize: bool,
     print_pidfile: bool,
     logger: logging.Logger,

@@ -171,6 +171,8 @@ def start_reactor(
     # appearing to go backwards.
     with PreserveLoggingContext():
         if daemonize:
+            assert pid_file is not None
+
             if print_pidfile:
                 print(pid_file)

@@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore

@@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
     SlavedDeviceStore,
     SlavedPushRuleStore,
     SlavedEventStore,
-    SlavedClientIpStore,
     BaseSlavedStore,
     RoomWorkerStore,
 ):

@@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.directory import DirectoryStore

@@ -247,7 +246,6 @@ class GenericWorkerSlavedStore(
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
     SlavedProfileStore,
-    SlavedClientIpStore,
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,

@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -22,7 +23,13 @@ from netaddr import IPSet

 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
+from synapse.types import (
+    DeviceListUpdates,
+    GroupID,
+    JsonDict,
+    UserID,
+    get_domain_from_id,
+)
 from synapse.util.caches.descriptors import _CacheContext, cached

 if TYPE_CHECKING:

@@ -400,6 +407,7 @@ class AppServiceTransaction:
         to_device_messages: List[JsonDict],
         one_time_key_counts: TransactionOneTimeKeyCounts,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
+        device_list_summary: DeviceListUpdates,
     ):
         self.service = service
         self.id = id

@@ -408,6 +416,7 @@ class AppServiceTransaction:
         self.to_device_messages = to_device_messages
         self.one_time_key_counts = one_time_key_counts
         self.unused_fallback_keys = unused_fallback_keys
+        self.device_list_summary = device_list_summary

     async def send(self, as_api: "ApplicationServiceApi") -> bool:
         """Sends this transaction using the provided AS API interface.

@@ -424,6 +433,7 @@ class AppServiceTransaction:
             to_device_messages=self.to_device_messages,
             one_time_key_counts=self.one_time_key_counts,
             unused_fallback_keys=self.unused_fallback_keys,
+            device_list_summary=self.device_list_summary,
             txn_id=self.id,
         )

@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -27,7 +28,7 @@ from synapse.appservice import (
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
 from synapse.http.client import SimpleHttpClient
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache

 if TYPE_CHECKING:

@@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
         to_device_messages: List[JsonDict],
         one_time_key_counts: TransactionOneTimeKeyCounts,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
+        device_list_summary: DeviceListUpdates,
         txn_id: Optional[int] = None,
     ) -> bool:
         """

@@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             }
         )

+        # TODO: Update to stable prefixes once MSC3202 completes FCP merge
         if service.msc3202_transaction_extensions:
             if one_time_key_counts:
                 body[

@@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
                 body[
                     "org.matrix.msc3202.device_unused_fallback_keys"
                 ] = unused_fallback_keys
+            if device_list_summary:
+                body["org.matrix.msc3202.device_lists"] = {
+                    "changed": list(device_list_summary.changed),
+                    "left": list(device_list_summary.left),
+                }

         try:
             await self.put_json(

@@ -72,7 +72,7 @@ from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases.main import DataStore
-from synapse.types import JsonDict
+from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import Clock

 if TYPE_CHECKING:

@@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
         events: Optional[Collection[EventBase]] = None,
         ephemeral: Optional[Collection[JsonDict]] = None,
         to_device_messages: Optional[Collection[JsonDict]] = None,
+        device_list_summary: Optional[DeviceListUpdates] = None,
     ) -> None:
         """
         Enqueue some data to be sent off to an application service.

@@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
             to_device_messages: The to-device messages to send. These differ from normal
                 to-device messages sent to clients, as they have 'to_device_id' and
                 'to_user_id' fields.
+            device_list_summary: A summary of users that the application service either needs
+                to refresh the device lists of, or those that the application service need no
+                longer track the device lists of.
         """
         # We purposefully allow this method to run with empty events/ephemeral
         # collections, so that callers do not need to check iterable size themselves.
-        if not events and not ephemeral and not to_device_messages:
+        if (
+            not events
+            and not ephemeral
+            and not to_device_messages
+            and not device_list_summary
+        ):
             return

         if events:

@@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
         self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
             to_device_messages
         )
+        if device_list_summary:
+            self.queuer.queued_device_list_summaries.setdefault(
+                appservice.id, []
+            ).append(device_list_summary)

         # Kick off a new application service transaction
         self.queuer.start_background_request(appservice)

@@ -169,6 +182,8 @@ class _ServiceQueuer:
         self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
         # dict of {service_id: [to_device_message_json]}
         self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
+        # dict of {service_id: [device_list_summary]}
+        self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}

         # the appservices which currently have a transaction in flight
         self.requests_in_flight: Set[str] = set()

@@ -212,7 +227,35 @@ class _ServiceQueuer:
             ]
             del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]

-            if not events and not ephemeral and not to_device_messages_to_send:
+            # Consolidate any pending device list summaries into a single, up-to-date
+            # summary.
+            # Note: this code assumes that in a single DeviceListUpdates, a user will
+            # never be in both "changed" and "left" sets.
+            device_list_summary = DeviceListUpdates()
+            for summary in self.queued_device_list_summaries.get(service.id, []):
+                # For every user in the incoming "changed" set:
+                #   * Remove them from the existing "left" set if necessary
+                #     (as we need to start tracking them again)
+                #   * Add them to the existing "changed" set if necessary.
+                device_list_summary.left.difference_update(summary.changed)
+                device_list_summary.changed.update(summary.changed)
+
+                # For every user in the incoming "left" set:
+                #   * Remove them from the existing "changed" set if necessary
+                #     (we no longer need to track them)
+                #   * Add them to the existing "left" set if necessary.
+                device_list_summary.changed.difference_update(summary.left)
+                device_list_summary.left.update(summary.left)
+            self.queued_device_list_summaries.clear()
+
+            if (
+                not events
+                and not ephemeral
+                and not to_device_messages_to_send
+                # DeviceListUpdates is True if either the 'changed' or 'left' sets have
+                # at least one entry, otherwise False
+                and not device_list_summary
+            ):
                 return

             one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None

@@ -240,6 +283,7 @@ class _ServiceQueuer:
                     to_device_messages_to_send,
                     one_time_key_counts,
                     unused_fallback_keys,
+                    device_list_summary,
                 )
             except Exception:
                 logger.exception("AS request failed")

@@ -322,6 +366,7 @@ class _TransactionController:
         to_device_messages: Optional[List[JsonDict]] = None,
         one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
         unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
+        device_list_summary: Optional[DeviceListUpdates] = None,
     ) -> None:
         """
         Create a transaction with the given data and send to the provided

@@ -336,6 +381,7 @@ class _TransactionController:
             appservice devices in the transaction.
             unused_fallback_keys: Lists of unused fallback keys for relevant
                 appservice devices in the transaction.
+            device_list_summary: The device list summary to include in the transaction.
         """
         try:
             txn = await self.store.create_appservice_txn(

@@ -345,6 +391,7 @@ class _TransactionController:
                 to_device_messages=to_device_messages or [],
                 one_time_key_counts=one_time_key_counts or {},
                 unused_fallback_keys=unused_fallback_keys or {},
+                device_list_summary=device_list_summary or DeviceListUpdates(),
             )
             service_is_up = await self._is_service_up(service)
             if service_is_up:

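The consolidation loop added in `_ServiceQueuer` above merges a queue of device-list summaries so that a user who later changed ends up only in "changed", and a user who later left ends up only in "left". A toy re-implementation of just that merge rule, using a dataclass stand-in for Synapse's `DeviceListUpdates` (which lives in synapse.types):

from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class DeviceListUpdates:
    # Users whose device lists changed, and users the application service
    # should stop tracking; mirrors the shape used in the hunks above.
    changed: Set[str] = field(default_factory=set)
    left: Set[str] = field(default_factory=set)

    def __bool__(self) -> bool:
        # Truthy when either set is non-empty, as the diff's comment notes.
        return bool(self.changed or self.left)


def consolidate(summaries: List[DeviceListUpdates]) -> DeviceListUpdates:
    merged = DeviceListUpdates()
    for summary in summaries:
        # A newly-changed user must not linger in "left", and vice versa;
        # the later summary wins for each user.
        merged.left.difference_update(summary.changed)
        merged.changed.update(summary.changed)
        merged.changed.difference_update(summary.left)
        merged.left.update(summary.left)
    return merged


# A user who left and then changed again ends up only in "changed":
result = consolidate(
    [DeviceListUpdates(left={"@a:hs"}), DeviceListUpdates(changed={"@a:hs"})]
)
assert result.changed == {"@a:hs"} and not result.left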
@@ -702,10 +702,7 @@ class RootConfig:
         return obj

     def parse_config_dict(
-        self,
-        config_dict: Dict[str, Any],
-        config_dir_path: Optional[str] = None,
-        data_dir_path: Optional[str] = None,
+        self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
     ) -> None:
         """Read the information from the config dict into this Config object.

@@ -126,10 +126,7 @@ class RootConfig:
     @classmethod
     def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
     def parse_config_dict(
-        self,
-        config_dict: Dict[str, Any],
-        config_dir_path: Optional[str] = ...,
-        data_dir_path: Optional[str] = ...,
+        self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
     ) -> None: ...
     def generate_config(
         self,

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any

 from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -29,7 +31,7 @@ https://matrix-org.github.io/synapse/latest/templates.html
 class AccountValidityConfig(Config):
     section = "account_validity"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         """Parses the old account validity config. The config format looks like this:

         account_validity:

@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import Iterable
+from typing import Any, Iterable

 from synapse.api.constants import EventTypes
 from synapse.config._base import Config, ConfigError

@@ -26,12 +26,12 @@ logger = logging.getLogger(__name__)
 class ApiConfig(Config):
     section = "api"

-    def read_config(self, config: JsonDict, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         validate_config(_MAIN_SCHEMA, config, ())
         self.room_prejoin_state = list(self._get_prejoin_state_types(config))
         self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)

-    def generate_config_section(cls, **kwargs) -> str:
+    def generate_config_section(cls, **kwargs: Any) -> str:
         formatted_default_state_types = "\n".join(
             " # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
         )

@@ -14,7 +14,7 @@
 # limitations under the License.

 import logging
-from typing import Dict, List
+from typing import Any, Dict, List
 from urllib import parse as urlparse

 import yaml

@@ -31,12 +31,12 @@ logger = logging.getLogger(__name__)
 class AppServiceConfig(Config):
     section = "appservice"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.app_service_config_files = config.get("app_service_config_files", [])
         self.notify_appservices = config.get("notify_appservices", True)
         self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)

-    def generate_config_section(cls, **kwargs) -> str:
+    def generate_config_section(cls, **kwargs: Any) -> str:
         return """\
         # A list of application service config files to use
         #

@@ -170,6 +170,7 @@ def _load_appservice(
     # When enabled, appservice transactions contain the following information:
     #   - device One-Time Key counts
     #   - device unused fallback key usage states
+    #   - device list changes
     msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
     if not isinstance(msc3202_transaction_extensions, bool):
         raise ValueError(

@@ -12,6 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict

 from ._base import Config

@@ -21,7 +24,7 @@ class AuthConfig(Config):

     section = "auth"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         password_config = config.get("password_config", {})
         if password_config is None:
             password_config = {}

@@ -40,7 +43,7 @@ class AuthConfig(Config):
             ui_auth.get("session_timeout", 0)
         )

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         password_config:
            # Uncomment to disable password login

@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+
+from synapse.types import JsonDict

 from ._base import Config

@@ -18,7 +21,7 @@ from ._base import Config
 class BackgroundUpdateConfig(Config):
     section = "background_updates"

-    def generate_config_section(self, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Background Updates ##

@@ -52,7 +55,7 @@ class BackgroundUpdateConfig(Config):
         #default_batch_size: 50
         """

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         bg_update_config = config.get("background_updates") or {}

         self.update_duration_ms = bg_update_config.get(

@@ -16,10 +16,11 @@ import logging
 import os
 import re
 import threading
-from typing import Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional

 import attr

+from synapse.types import JsonDict
 from synapse.util.check_dependencies import DependencyException, check_requirements

 from ._base import Config, ConfigError

@@ -105,7 +106,7 @@ class CacheConfig(Config):
         with _CACHES_LOCK:
             _CACHES.clear()

-    def generate_config_section(self, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Caching ##

@@ -172,7 +173,7 @@ class CacheConfig(Config):
         #sync_response_cache_duration: 2m
         """

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )

@@ -12,15 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import Config
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config, ConfigError


 class CaptchaConfig(Config):
     section = "captcha"

-    def read_config(self, config, **kwargs):
-        self.recaptcha_private_key = config.get("recaptcha_private_key")
-        self.recaptcha_public_key = config.get("recaptcha_public_key")
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        recaptcha_private_key = config.get("recaptcha_private_key")
+        if recaptcha_private_key is not None and not isinstance(
+            recaptcha_private_key, str
+        ):
+            raise ConfigError("recaptcha_private_key must be a string.")
+        self.recaptcha_private_key = recaptcha_private_key
+
+        recaptcha_public_key = config.get("recaptcha_public_key")
+        if recaptcha_public_key is not None and not isinstance(
+            recaptcha_public_key, str
+        ):
+            raise ConfigError("recaptcha_public_key must be a string.")
+        self.recaptcha_public_key = recaptcha_public_key
+
         self.enable_registration_captcha = config.get(
             "enable_registration_captcha", False
         )

@@ -30,7 +46,7 @@ class CaptchaConfig(Config):
         )
         self.recaptcha_template = self.read_template("recaptcha.html")

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Captcha ##
         # See docs/CAPTCHA_SETUP.md for full details of configuring this.
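Beyond annotations, the captcha hunk adds runtime validation, since YAML can hand back integers or lists where a string is expected. A hedged sketch of the same guard as a free function (the `ConfigError` class is a local stand-in for Synapse's exception):

from typing import Any, Dict, Optional

JsonDict = Dict[str, Any]


class ConfigError(Exception):
    """Stand-in for synapse.config._base.ConfigError."""


def read_optional_str(config: JsonDict, key: str) -> Optional[str]:
    # A missing key is fine (None); a present key must be a string.
    value = config.get(key)
    if value is not None and not isinstance(value, str):
        raise ConfigError(f"{key} must be a string.")
    return value


# read_optional_str({}, "recaptcha_private_key") returns None;
# read_optional_str({"recaptcha_private_key": 123}, ...) raises ConfigError.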
@@ -16,6 +16,7 @@
 from typing import Any, List

 from synapse.config.sso import SsoAttributeRequirement
+from synapse.types import JsonDict

 from ._base import Config
 from ._util import validate_config

@@ -29,7 +30,7 @@ class CasConfig(Config):

     section = "cas"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         cas_config = config.get("cas_config", None)
         self.cas_enabled = cas_config and cas_config.get("enabled", True)

@@ -52,7 +53,7 @@ class CasConfig(Config):
         self.cas_displayname_attribute = None
         self.cas_required_attributes = []

-    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
@@ -13,9 +13,10 @@
 # limitations under the License.

 from os import path
-from typing import Optional
+from typing import Any, Optional

 from synapse.config import ConfigError
+from synapse.types import JsonDict

 from ._base import Config

@@ -76,18 +77,18 @@ class ConsentConfig(Config):

     section = "consent"

-    def __init__(self, *args):
+    def __init__(self, *args: Any):
         super().__init__(*args)

         self.user_consent_version: Optional[str] = None
         self.user_consent_template_dir: Optional[str] = None
-        self.user_consent_server_notice_content = None
+        self.user_consent_server_notice_content: Optional[JsonDict] = None
         self.user_consent_server_notice_to_guests = False
-        self.block_events_without_consent_error = None
+        self.block_events_without_consent_error: Optional[str] = None
         self.user_consent_at_registration = False
         self.user_consent_policy_name = "Privacy Policy"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         consent_config = config.get("user_consent")
         self.terms_template = self.read_template("terms.html")

@@ -118,5 +119,5 @@ class ConsentConfig(Config):
             "policy_name", "Privacy Policy"
         )

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return DEFAULT_CONFIG
@@ -15,8 +15,10 @@
 import argparse
 import logging
 import os
+from typing import Any, List

 from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -121,12 +123,12 @@ class DatabaseConnectionConfig:
 class DatabaseConfig(Config):
     section = "database"

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args: Any):
+        super().__init__(*args)

-        self.databases = []
+        self.databases: List[DatabaseConnectionConfig] = []

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         # We *experimentally* support specifying multiple databases via the
         # `databases` key. This is a map from a label to database config in the
         # same format as the `database` config option, plus an extra

@@ -170,7 +172,7 @@ class DatabaseConfig(Config):
         self.databases = [DatabaseConnectionConfig("master", database_config)]
         self.set_databasepath(database_path)

-    def generate_config_section(self, data_dir_path, **kwargs) -> str:
+    def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
         return DEFAULT_CONFIG % {
             "database_path": os.path.join(data_dir_path, "homeserver.db")
         }
@@ -19,9 +19,12 @@ import email.utils
 import logging
 import os
 from enum import Enum
+from typing import Any

 import attr

+from synapse.types import JsonDict
+
 from ._base import Config, ConfigError

 logger = logging.getLogger(__name__)

@@ -73,7 +76,7 @@ class EmailSubjectConfig:
 class EmailConfig(Config):
     section = "email"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         # TODO: We should separate better the email configuration from the notification
         # and account validity config.

@@ -354,7 +357,7 @@ class EmailConfig(Config):
             path=("email", "invite_client_location"),
         )

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return (
             """\
         # Configuration for sending emails from Synapse.
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 from synapse.config._base import Config
 from synapse.types import JsonDict

@@ -21,13 +23,11 @@ class ExperimentalConfig(Config):

     section = "experimental"

-    def read_config(self, config: JsonDict, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         experimental = config.get("experimental_features") or {}

-        # MSC3440 (thread relation)
-        self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
         # MSC3666: including bundled relations in /search.
         self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)

         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)

@@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
             "msc3202_device_masquerading", False
         )

-        # Portion of MSC3202 related to transaction extensions:
-        # sending one-time key counts and fallback key usage to application services.
+        # The portion of MSC3202 related to transaction extensions:
+        # sending device list changes, one-time key counts and fallback key
+        # usage to application services.
         self.msc3202_transaction_extensions: bool = experimental.get(
             "msc3202_transaction_extensions", False
         )

@@ -77,3 +78,6 @@ class ExperimentalConfig(Config):

         # The deprecated groups feature.
         self.groups_enabled: bool = experimental.get("groups_enabled", True)
+
+        # MSC2654: Unread counts
+        self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
@@ -11,16 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Optional

 from synapse.config._base import Config
 from synapse.config._util import validate_config
+from synapse.types import JsonDict


 class FederationConfig(Config):
     section = "federation"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         # FIXME: federation_domain_whitelist needs sytests
         self.federation_domain_whitelist: Optional[dict] = None
         federation_domain_whitelist = config.get("federation_domain_whitelist", None)

@@ -48,7 +49,7 @@ class FederationConfig(Config):
             "allow_device_name_lookup_over_federation", True
         )

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Federation ##

@@ -12,17 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
+
 from ._base import Config


 class GroupsConfig(Config):
     section = "groups"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_group_creation = config.get("enable_group_creation", False)
         self.group_creation_prefix = config.get("group_creation_prefix", "")

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # Uncomment to allow non-server-admin users to create groups on this server
         #
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
+
 from ._base import Config, ConfigError

 MISSING_JWT = """Missing jwt library. This is required for jwt login.

@@ -24,7 +28,7 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
 class JWTConfig(Config):
     section = "jwt"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         jwt_config = config.get("jwt_config", None)
         if jwt_config:
             self.jwt_enabled = jwt_config.get("enabled", False)

@@ -52,7 +56,7 @@ class JWTConfig(Config):
         self.jwt_issuer = None
         self.jwt_audiences = None

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # JSON web token integration. The following settings can be used to make
         # Synapse JSON web tokens for authentication, instead of its internal
@@ -16,7 +16,7 @@
 import hashlib
 import logging
 import os
-from typing import Any, Dict, Iterator, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

 import attr
 import jsonschema

@@ -38,6 +38,9 @@ from synapse.util.stringutils import random_string, random_string_with_symbols

 from ._base import Config, ConfigError

+if TYPE_CHECKING:
+    from signedjson.key import VerifyKeyWithExpiry
+
 INSECURE_NOTARY_ERROR = """\
 Your server is configured to accept key server responses without signature
 validation or TLS certificate validation. This is likely to be very insecure. If

@@ -96,11 +99,14 @@ class TrustedKeyServer:
 class KeyConfig(Config):
     section = "key"

-    def read_config(self, config, config_dir_path, **kwargs):
+    def read_config(
+        self, config: JsonDict, config_dir_path: str, **kwargs: Any
+    ) -> None:
         # the signing key can be specified inline or in a separate file
         if "signing_key" in config:
             self.signing_key = read_signing_keys([config["signing_key"]])
         else:
+            assert config_dir_path is not None
             signing_key_path = config.get("signing_key_path")
             if signing_key_path is None:
                 signing_key_path = os.path.join(

@@ -169,8 +175,12 @@ class KeyConfig(Config):
         self.form_secret = config.get("form_secret", None)

     def generate_config_section(
-        self, config_dir_path, server_name, generate_secrets=False, **kwargs
-    ):
+        self,
+        config_dir_path: str,
+        server_name: str,
+        generate_secrets: bool = False,
+        **kwargs: Any,
+    ) -> str:
         base_key_name = os.path.join(config_dir_path, server_name)

         if generate_secrets:

@@ -300,7 +310,7 @@ class KeyConfig(Config):

     def read_old_signing_keys(
         self, old_signing_keys: Optional[JsonDict]
-    ) -> Dict[str, VerifyKey]:
+    ) -> Dict[str, "VerifyKeyWithExpiry"]:
         if old_signing_keys is None:
             return {}
         keys = {}

@@ -308,8 +318,8 @@ class KeyConfig(Config):
             if is_signing_algorithm_supported(key_id):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_key.expired_ts = key_data["expired_ts"]
+                verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes)  # type: ignore[assignment]
+                verify_key.expired = key_data["expired_ts"]
                 keys[key_id] = verify_key
             else:
                 raise ConfigError(

@@ -422,7 +432,7 @@ def _parse_key_servers(
         server_name = server["server_name"]
         result = TrustedKeyServer(server_name=server_name)

-        verify_keys = server.get("verify_keys")
+        verify_keys: Optional[Dict[str, str]] = server.get("verify_keys")
         if verify_keys is not None:
             result.verify_keys = {}
             for key_id, key_base64 in verify_keys.items():
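The key-config hunks rename the expiry attribute attached to old verify keys from `expired_ts` to `expired`, matching the `VerifyKeyWithExpiry` annotation imported under `TYPE_CHECKING`; the Keyring hunk further down reads the value back as `valid_until_ts=key.expired`. A toy sketch of that producer/consumer pair, with a plain class standing in for the real verify-key object:

from typing import Any, Dict


class FakeVerifyKey:
    """Stand-in for a signedjson/nacl verify key."""

    def __init__(self, key_id: str) -> None:
        self.key_id = key_id


def read_old_signing_keys(raw: Dict[str, Dict[str, Any]]) -> Dict[str, FakeVerifyKey]:
    keys: Dict[str, FakeVerifyKey] = {}
    for key_id, key_data in raw.items():
        key = FakeVerifyKey(key_id)
        # The expiry now lands on `.expired` (previously `.expired_ts`).
        key.expired = key_data["expired_ts"]
        keys[key_id] = key
    return keys


keys = read_old_signing_keys({"ed25519:a": {"expired_ts": 1_650_000_000_000}})
# The consumer side, as in the Keyring hunk below:
valid_until_ts = keys["ed25519:a"].expired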
@@ -35,6 +35,7 @@ from twisted.logger import (

 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
+from synapse.types import JsonDict

 from ._base import Config, ConfigError

@@ -147,13 +148,15 @@ https://matrix-org.github.io/synapse/v1.54/structured_logging.html
 class LoggingConfig(Config):
     section = "logging"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         if config.get("log_file"):
             raise ConfigError(LOG_FILE_ERROR)
         self.log_config = self.abspath(config.get("log_config"))
         self.no_redirect_stdio = config.get("no_redirect_stdio", False)

-    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+    def generate_config_section(
+        self, config_dir_path: str, server_name: str, **kwargs: Any
+    ) -> str:
         log_config = os.path.join(config_dir_path, server_name + ".log.config")
         return (
             """\
@@ -13,8 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any, Optional
+
 import attr

+from synapse.types import JsonDict
 from synapse.util.check_dependencies import DependencyException, check_requirements

 from ._base import Config, ConfigError

@@ -37,7 +40,7 @@ class MetricsFlags:
 class MetricsConfig(Config):
     section = "metrics"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_metrics = config.get("enable_metrics", False)
         self.report_stats = config.get("report_stats", None)
         self.report_stats_endpoint = config.get(

@@ -67,7 +70,9 @@ class MetricsConfig(Config):
             "sentry.dsn field is required when sentry integration is enabled"
         )

-    def generate_config_section(self, report_stats=None, **kwargs):
+    def generate_config_section(
+        self, report_stats: Optional[bool] = None, **kwargs: Any
+    ) -> str:
         res = """\
         ## Metrics ###

@@ -14,13 +14,14 @@
 from typing import Any, Dict, List, Tuple

 from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module


 class ModulesConfig(Config):
     section = "modules"

-    def read_config(self, config: dict, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.loaded_modules: List[Tuple[Any, Dict]] = []

         configured_modules = config.get("modules") or []

@@ -31,7 +32,7 @@ class ModulesConfig(Config):

             self.loaded_modules.append(load_module(module, config_path))

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """
         ## Modules ##

@@ -40,7 +40,7 @@ class OembedConfig(Config):

     section = "oembed"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         oembed_config: Dict[str, Any] = config.get("oembed") or {}

         # A list of patterns which will be used.

@@ -143,7 +143,7 @@ class OembedConfig(Config):
         )
         return re.compile(pattern)

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # oEmbed allows for easier embedding content from a website. It can be
         # used for generating URLs previews of services which support it.
@@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
 class OIDCConfig(Config):
     section = "oidc"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
         if not self.oidc_providers:
             return

@@ -66,7 +66,7 @@ class OIDCConfig(Config):
         # OIDC is enabled if we have a provider
         return bool(self.oidc_providers)

-    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
         # and login.
@@ -14,6 +14,7 @@

 from typing import Any, List, Tuple, Type

+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module

 from ._base import Config

@@ -24,7 +25,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
 class PasswordAuthProviderConfig(Config):
     section = "authproviders"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         """Parses the old password auth providers config. The config format looks like this:

         password_providers:
@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
+
 from ._base import Config


 class PushConfig(Config):
     section = "push"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         push_config = config.get("push") or {}
         self.push_include_content = push_config.get("include_content", True)
         self.push_group_unread_count_by_room = push_config.get(

@@ -46,7 +50,7 @@ class PushConfig(Config):
         )
         self.push_include_content = not redact_content

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """
         ## Push ##

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 import attr

+from synapse.types import JsonDict
+
 from ._base import Config


@@ -43,7 +45,7 @@ class FederationRateLimitConfig:
 class RatelimitConfig(Config):
     section = "ratelimiting"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:

         # Load the new-style messages config if it exists. Otherwise fall back
         # to the old method.

@@ -142,7 +144,7 @@ class RatelimitConfig(Config):
             },
         )

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Ratelimiting ##

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 from synapse.config._base import Config
+from synapse.types import JsonDict
 from synapse.util.check_dependencies import check_requirements


 class RedisConfig(Config):
     section = "redis"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         redis_config = config.get("redis") or {}
         self.redis_enabled = redis_config.get("enabled", False)

@@ -32,7 +35,7 @@ class RedisConfig(Config):
         self.redis_port = redis_config.get("port", 6379)
         self.redis_password = redis_config.get("password")

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # Configuration for Redis when using workers. This *must* be enabled when
         # using workers (unless using old style direct TCP configuration).
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from typing import Optional
+from typing import Any, Optional

 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
-from synapse.types import RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string_with_symbols, strtobool


 class RegistrationConfig(Config):
     section = "registration"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_registration = strtobool(
             str(config.get("enable_registration", False))
         )

@@ -196,7 +196,9 @@ class RegistrationConfig(Config):

         self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False)

-    def generate_config_section(self, generate_secrets=False, **kwargs):
+    def generate_config_section(
+        self, generate_secrets: bool = False, **kwargs: Any
+    ) -> str:
         if generate_secrets:
             registration_shared_secret = 'registration_shared_secret: "%s"' % (
                 random_string_with_symbols(50),
@@ -14,7 +14,7 @@

 import logging
 import os
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 from urllib.request import getproxies_environment  # type: ignore

 import attr

@@ -95,7 +95,7 @@ def parse_thumbnail_requirements(
 class ContentRepositoryConfig(Config):
     section = "media"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:

         # Only enable the media repo if either the media repo is enabled or the
         # current worker app is the media repo.

@@ -224,7 +224,8 @@ class ContentRepositoryConfig(Config):
             "url_preview_accept_language"
         ) or ["en"]

-    def generate_config_section(self, data_dir_path, **kwargs):
+    def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
+        assert data_dir_path is not None
         media_store = os.path.join(data_dir_path, "media_store")

         formatted_thumbnail_sizes = "".join(
@@ -13,11 +13,12 @@
 # limitations under the License.

 import logging
-from typing import List, Optional
+from typing import Any, List, Optional

 import attr

 from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -34,7 +35,7 @@ class RetentionPurgeJob:
 class RetentionConfig(Config):
     section = "retention"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         retention_config = config.get("retention")
         if retention_config is None:
             retention_config = {}

@@ -153,7 +154,7 @@ class RetentionConfig(Config):
             RetentionPurgeJob(self.parse_duration("1d"), None, None)
         ]

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # Message retention policy at the server level.
         #
@@ -13,8 +13,10 @@
 # limitations under the License.

 import logging
+from typing import Any

 from synapse.api.constants import RoomCreationPreset
+from synapse.types import JsonDict

 from ._base import Config, ConfigError

@@ -32,7 +34,7 @@ class RoomDefaultEncryptionTypes:
 class RoomConfig(Config):
     section = "room"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         # Whether new, locally-created rooms should have encryption enabled
         encryption_for_room_type = config.get(
             "encryption_enabled_by_default_for_room_type",

@@ -61,7 +63,7 @@ class RoomConfig(Config):
                 "Invalid value for encryption_enabled_by_default_for_room_type"
             )

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Rooms ##

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import Any, List

 from matrix_common.regex import glob_to_regex

@@ -25,7 +25,7 @@ from ._base import Config, ConfigError
 class RoomDirectoryConfig(Config):
     section = "roomdirectory"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_room_list_search = config.get("enable_room_list_search", True)

         alias_creation_rules = config.get("alias_creation_rules")

@@ -52,7 +52,7 @@ class RoomDirectoryConfig(Config):
             _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
         ]

-    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """
         # Uncomment to disable searching the public room list. When disabled
         # blocks searching local and remote room lists for local and remote
@@ -65,7 +65,7 @@ def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
 class SAML2Config(Config):
     section = "saml2"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.saml2_enabled = False

         saml2_config = config.get("saml2_config")

@@ -165,13 +165,13 @@ class SAML2Config(Config):
         config_path = saml2_config.get("config_path", None)
         if config_path is not None:
             mod = load_python_module(config_path)
-            config = getattr(mod, "CONFIG", None)
-            if config is None:
+            config_dict_from_file = getattr(mod, "CONFIG", None)
+            if config_dict_from_file is None:
                 raise ConfigError(
                     "Config path specified by saml2_config.config_path does not "
                     "have a CONFIG property."
                 )
-            _dict_merge(merge_dict=config, into_dict=saml2_config_dict)
+            _dict_merge(merge_dict=config_dict_from_file, into_dict=saml2_config_dict)

         import saml2.config

@@ -223,7 +223,7 @@ class SAML2Config(Config):
             },
         }

-    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
+    def generate_config_section(self, config_dir_path: str, **kwargs: Any) -> str:
         return """\
         ## Single sign-on integration ##

@@ -248,7 +248,7 @@ class LimitRemoteRoomsConfig:
 class ServerConfig(Config):
     section = "server"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.server_name = config["server_name"]
         self.server_context = config.get("server_context", None)

@@ -259,8 +259,8 @@ class ServerConfig(Config):

         self.pid_file = self.abspath(config.get("pid_file"))
         self.soft_file_limit = config.get("soft_file_limit", 0)
-        self.daemonize = config.get("daemonize")
-        self.print_pidfile = config.get("print_pidfile")
+        self.daemonize = bool(config.get("daemonize"))
+        self.print_pidfile = bool(config.get("print_pidfile"))
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
         self.serve_server_wellknown = config.get("serve_server_wellknown", False)

@@ -680,18 +680,30 @@ class ServerConfig(Config):
             config.get("use_account_validity_in_account_status") or False
         )

+        # This is a temporary option that enables fully using the new
+        # `device_lists_changes_in_room` without the backwards compat code. This
+        # is primarily for testing. If enabled the server should *not* be
+        # downgraded, as it may lead to missing device list updates.
+        self.use_new_device_lists_changes_in_room = (
+            config.get("use_new_device_lists_changes_in_room") or False
+        )
+
+        self.rooms_to_exclude_from_sync: List[str] = (
+            config.get("exclude_rooms_from_sync") or []
+        )
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)

     def generate_config_section(
         self,
-        server_name,
-        data_dir_path,
-        open_private_ports,
-        listeners,
-        config_dir_path,
-        **kwargs,
-    ):
+        config_dir_path: str,
+        data_dir_path: str,
+        server_name: str,
+        open_private_ports: bool,
+        listeners: Optional[List[dict]],
+        **kwargs: Any,
+    ) -> str:
         ip_range_blacklist = "\n".join(
             " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
         )

@@ -1234,6 +1246,15 @@ class ServerConfig(Config):
         # information about using custom templates.
         #
         #custom_template_directory: /path/to/custom/templates/
+
+        # List of rooms to exclude from sync responses. This is useful for server
+        # administrators wishing to group users into a room without these users being able
+        # to see it from their client.
+        #
+        # By default, no room is excluded.
+        #
+        #exclude_rooms_from_sync:
+        #  - !foo:example.com
         """
             % locals()
         )
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.types import UserID
+
+from typing import Any, Optional
+
+from synapse.types import JsonDict, UserID

 from ._base import Config

@@ -60,14 +63,14 @@ class ServerNoticesConfig(Config):

     section = "servernotices"

-    def __init__(self, *args):
+    def __init__(self, *args: Any):
         super().__init__(*args)
-        self.server_notices_mxid = None
-        self.server_notices_mxid_display_name = None
-        self.server_notices_mxid_avatar_url = None
-        self.server_notices_room_name = None
+        self.server_notices_mxid: Optional[str] = None
+        self.server_notices_mxid_display_name: Optional[str] = None
+        self.server_notices_mxid_avatar_url: Optional[str] = None
+        self.server_notices_room_name: Optional[str] = None

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         c = config.get("server_notices")
         if c is None:
             return

@@ -81,5 +84,5 @@ class ServerNoticesConfig(Config):
         # todo: i18n
         self.server_notices_room_name = c.get("room_name", "Server Notices")

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return DEFAULT_CONFIG
@@ -16,6 +16,7 @@ import logging
 from typing import Any, Dict, List, Tuple

 from synapse.config import ConfigError
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module

 from ._base import Config

@@ -33,7 +34,7 @@ see https://matrix-org.github.io/synapse/latest/modules/index.html
 class SpamCheckerConfig(Config):
     section = "spamchecker"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.spam_checkers: List[Tuple[Any, Dict]] = []

         spam_checkers = config.get("spam_checker") or []
@@ -16,6 +16,8 @@ from typing import Any, Dict, Optional

 import attr

+from synapse.types import JsonDict
+
 from ._base import Config

 logger = logging.getLogger(__name__)

@@ -49,7 +51,7 @@ class SSOConfig(Config):

     section = "sso"

-    def read_config(self, config, **kwargs) -> None:
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         sso_config: Dict[str, Any] = config.get("sso") or {}

         # The sso-specific template_dir

@@ -106,7 +108,7 @@ class SSOConfig(Config):
             )
             self.sso_client_whitelist.append(login_fallback_url)

-    def generate_config_section(self, **kwargs) -> str:
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         # Additional settings to use with single-sign on systems such as OpenID Connect,
         # SAML2 and CAS.
@@ -13,6 +13,9 @@
 # limitations under the License.

 import logging
+from typing import Any
+
+from synapse.types import JsonDict

 from ._base import Config

@@ -36,7 +39,7 @@ class StatsConfig(Config):

     section = "stats"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.stats_enabled = True
         stats_config = config.get("stats", None)
         if stats_config:

@@ -44,7 +47,7 @@ class StatsConfig(Config):
         if not self.stats_enabled:
             logger.warning(ROOM_STATS_DISABLED_WARN)

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """
         # Settings for local room and user statistics collection. See
         # https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module

 from ._base import Config

@@ -20,7 +23,7 @@ from ._base import Config
 class ThirdPartyRulesConfig(Config):
     section = "thirdpartyrules"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.third_party_event_rules = None

         provider = config.get("third_party_event_rules", None)
@@ -14,7 +14,7 @@

 import logging
 import os
-from typing import List, Optional, Pattern
+from typing import Any, List, Optional, Pattern

 from matrix_common.regex import glob_to_regex

@@ -22,6 +22,7 @@ from OpenSSL import SSL, crypto
 from twisted.internet._sslverify import Certificate, trustRootFromCertificates

 from synapse.config._base import Config, ConfigError
+from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
 class TlsConfig(Config):
     section = "tls"

-    def read_config(self, config: dict, config_dir_path: str, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:

         self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
         self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))

@@ -142,13 +143,13 @@ class TlsConfig(Config):

     def generate_config_section(
         self,
-        config_dir_path,
-        server_name,
-        data_dir_path,
-        tls_certificate_path,
-        tls_private_key_path,
-        **kwargs,
-    ):
+        config_dir_path: str,
+        data_dir_path: str,
+        server_name: str,
+        tls_certificate_path: Optional[str],
+        tls_private_key_path: Optional[str],
+        **kwargs: Any,
+    ) -> str:
         """If the TLS paths are not specified the default will be certs in the
         config directory"""

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Set
+from typing import Any, Set

+from synapse.types import JsonDict
 from synapse.util.check_dependencies import DependencyException, check_requirements

 from ._base import Config, ConfigError

@@ -22,7 +23,7 @@ from ._base import Config, ConfigError
 class TracerConfig(Config):
     section = "tracing"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         opentracing_config = config.get("opentracing")
         if opentracing_config is None:
             opentracing_config = {}

@@ -65,7 +66,7 @@ class TracerConfig(Config):
             )
             self.force_tracing_for_users.add(u)

-    def generate_config_section(cls, **kwargs):
+    def generate_config_section(cls, **kwargs: Any) -> str:
         return """\
         ## Opentracing ##

@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
+
 from ._base import Config


@@ -22,7 +26,7 @@ class UserDirectoryConfig(Config):

     section = "userdirectory"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         user_directory_config = config.get("user_directory") or {}
         self.user_directory_search_enabled = user_directory_config.get("enabled", True)
         self.user_directory_search_all_users = user_directory_config.get(

@@ -32,7 +36,7 @@ class UserDirectoryConfig(Config):
             "prefer_local_users", False
         )

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """
         # User Directory configuration
         #
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
+from synapse.types import JsonDict
+
 from ._base import Config


 class VoipConfig(Config):
     section = "voip"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.turn_uris = config.get("turn_uris", [])
         self.turn_shared_secret = config.get("turn_shared_secret")
         self.turn_username = config.get("turn_username")

@@ -28,7 +32,7 @@ class VoipConfig(Config):
         )
         self.turn_allow_guests = config.get("turn_allow_guests", True)

-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## TURN ##

@@ -14,10 +14,12 @@
 # limitations under the License.

 import argparse
-from typing import List, Union
+from typing import Any, List, Union

 import attr

+from synapse.types import JsonDict
+
 from ._base import (
     Config,
     ConfigError,

@@ -110,7 +112,7 @@ class WorkerConfig(Config):

     section = "worker"

-    def read_config(self, config, **kwargs):
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.worker_app = config.get("worker_app")

         # Canonicalise worker_app so that master always has None

@@ -120,9 +122,13 @@ class WorkerConfig(Config):
         self.worker_listeners = [
             parse_listener_def(x) for x in config.get("worker_listeners", [])
         ]
-        self.worker_daemonize = config.get("worker_daemonize")
+        self.worker_daemonize = bool(config.get("worker_daemonize"))
         self.worker_pid_file = config.get("worker_pid_file")
-        self.worker_log_config = config.get("worker_log_config")
+
+        worker_log_config = config.get("worker_log_config")
+        if worker_log_config is not None and not isinstance(worker_log_config, str):
+            raise ConfigError("worker_log_config must be a string")
+        self.worker_log_config = worker_log_config

         # The host used to connect to the main synapse
         self.worker_replication_host = config.get("worker_replication_host", None)

@@ -290,7 +296,7 @@ class WorkerConfig(Config):
             self.worker_name is None and background_tasks_instance == "master"
         ) or self.worker_name == background_tasks_instance

-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Workers ##

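Two small hardening moves in the worker hunks are easy to miss: `bool(config.get("worker_daemonize"))` normalises a missing key to `False` rather than `None`, and `worker_log_config` is type-checked before being stored. A compact sketch of both, again with local stand-ins for `JsonDict` and `ConfigError`:

from typing import Any, Dict, Optional, Tuple

JsonDict = Dict[str, Any]


class ConfigError(Exception):
    """Stand-in for synapse.config._base.ConfigError."""


def read_worker_basics(config: JsonDict) -> Tuple[bool, Optional[str]]:
    # bool() collapses None/0/"" to False, so callers always get a real bool.
    daemonize = bool(config.get("worker_daemonize"))

    log_config = config.get("worker_log_config")
    if log_config is not None and not isinstance(log_config, str):
        raise ConfigError("worker_log_config must be a string")
    return daemonize, log_config


assert read_worker_basics({}) == (False, None)
assert read_worker_basics({"worker_daemonize": 1}) == (True, None)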
@@ -176,7 +176,7 @@ class Keyring:
         self._local_verify_keys: Dict[str, FetchKeyResult] = {}
         for key_id, key in hs.config.key.old_signing_keys.items():
             self._local_verify_keys[key_id] = FetchKeyResult(
-                verify_key=key, valid_until_ts=key.expired_ts
+                verify_key=key, valid_until_ts=key.expired
             )

         vk = get_verify_key(hs.signing_key)
@@ -15,7 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import attr
-from nacl.signing import SigningKey
+from signedjson.types import SigningKey

 from synapse.api.constants import MAX_DEPTH
 from synapse.api.room_versions import (
@@ -42,6 +42,7 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
 CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
 ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
 ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
+ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]


 def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:

@@ -169,6 +170,7 @@ class ThirdPartyEventRules:
         self._on_user_deactivation_status_changed_callbacks: List[
             ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
         ] = []
+        self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []

     def register_third_party_rules_callbacks(
         self,

@@ -187,6 +189,7 @@ class ThirdPartyEventRules:
         on_user_deactivation_status_changed: Optional[
             ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
         ] = None,
+        on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:

@@ -221,6 +224,9 @@ class ThirdPartyEventRules:
                 on_user_deactivation_status_changed,
             )

+        if on_threepid_bind is not None:
+            self._on_threepid_bind_callbacks.append(on_threepid_bind)
+
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
     ) -> Tuple[bool, Optional[dict]]:

@@ -479,3 +485,23 @@ class ThirdPartyEventRules:
                 logger.exception(
                     "Failed to run module API callback %s: %s", callback, e
                 )
+
+    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
+        """Called after a threepid association has been verified and stored.
+
+        Note that this callback is called when an association is created on the
+        local homeserver, not when it's created on an identity server (and then kept track
+        of so that it can be unbound on the same IS later on).
+
+        Args:
+            user_id: the user being associated with the threepid.
+            medium: the threepid's medium.
+            address: the threepid's address.
+        """
+        for callback in self._on_threepid_bind_callbacks:
+            try:
+                await callback(user_id, medium, address)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
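The third-party-rules hunks add a new module hook, `on_threepid_bind`, whose dispatch loop deliberately swallows per-callback exceptions so one broken module cannot block the rest. A self-contained sketch of that plumbing; the class and callback below are illustrative stand-ins, not Synapse's module API:

import asyncio
import logging
from typing import Awaitable, Callable, List

logger = logging.getLogger(__name__)

# Same shape as ON_THREEPID_BIND_CALLBACK above: (user_id, medium, address).
OnThreepidBind = Callable[[str, str, str], Awaitable]


class ThreepidBindHooks:
    def __init__(self) -> None:
        self._callbacks: List[OnThreepidBind] = []

    def register(self, callback: OnThreepidBind) -> None:
        self._callbacks.append(callback)

    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
        for callback in self._callbacks:
            try:
                await callback(user_id, medium, address)
            except Exception as e:
                # Isolate failures: log and keep going.
                logger.exception("Failed to run module API callback %s: %s", callback, e)


async def audit(user_id: str, medium: str, address: str) -> None:
    print(f"bound {medium}:{address} to {user_id}")


hooks = ThreepidBindHooks()
hooks.register(audit)
asyncio.run(hooks.on_threepid_bind("@alice:example.com", "email", "alice@example.com"))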
@@ -56,6 +56,7 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.types import QueryParams
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache

@@ -154,7 +155,7 @@ class FederationClient(FederationBase):
         self,
         destination: str,
         query_type: str,
-        args: dict,
+        args: QueryParams,
         retry_on_dns_fail: bool = False,
         ignore_backoff: bool = False,
     ) -> JsonDict:
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)

@@ -218,7 +218,7 @@ class FederationServer(FederationBase):
             Tuple indicating the response status code and dictionary response
             body including `event_id`.
         """
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)

@@ -529,7 +529,7 @@ class FederationServer(FederationBase):
         # in the cache so we could return it without waiting for the linearizer
         # - but that's non-trivial to get right, and anyway somewhat defeats
         # the point of the linearizer.
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
            resp: JsonDict = dict(
                await self._state_resp_cache.wrap(
                    (room_id, event_id),

@@ -883,7 +883,7 @@ class FederationServer(FederationBase):
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str
     ) -> Tuple[int, Dict[str, Any]]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)

@@ -945,7 +945,7 @@ class FederationServer(FederationBase):
         latest_events: List[str],
         limit: int,
     ) -> Dict[str, list]:
-        with (await self._server_linearizer.queue((origin, room_id))):
+        async with self._server_linearizer.queue((origin, room_id)):
             origin_host, _ = parse_server_name(origin)
             await self.check_server_matches_acl(origin_host, room_id)

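Every FederationServer hunk makes the same mechanical swap: `with (await linearizer.queue(key))` becomes `async with linearizer.queue(key)`, meaning the linearizer now hands out an async context manager directly. A toy linearizer demonstrating the new calling convention; this mirrors the idea, not the implementation, of `synapse.util.async_helpers.Linearizer`:

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Hashable


class ToyLinearizer:
    """Serialises work per key via an async context manager."""

    def __init__(self) -> None:
        self._locks: Dict[Hashable, asyncio.Lock] = {}

    @asynccontextmanager
    async def queue(self, key: Hashable) -> AsyncIterator[None]:
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:
            yield


async def on_backfill_request(linearizer: ToyLinearizer, origin: str, room_id: str) -> None:
    # The post-change convention: no intermediate `await` needed.
    async with linearizer.queue((origin, room_id)):
        # Only one task per (origin, room_id) executes here at a time.
        await asyncio.sleep(0)


asyncio.run(on_backfill_request(ToyLinearizer(), "example.com", "!room:example.com"))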
@@ -44,6 +44,7 @@ from synapse.api.urls import (
 from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.units import Transaction
 from synapse.http.matrixfederationclient import ByteParser
+from synapse.http.types import QueryParams
 from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -255,7 +256,7 @@ class TransportLayerClient:
         self,
         destination: str,
         query_type: str,
-        args: dict,
+        args: QueryParams,
         retry_on_dns_fail: bool,
         ignore_backoff: bool = False,
         prefix: str = FEDERATION_V1_PREFIX,

@@ -481,7 +482,7 @@ class TransportLayerClient:
         if third_party_instance_id:
             data["third_party_instance_id"] = third_party_instance_id
         if limit:
-            data["limit"] = str(limit)
+            data["limit"] = limit
         if since_token:
             data["since"] = since_token

@@ -503,7 +504,7 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/publicRooms")

-        args: Dict[str, Any] = {
+        args: Dict[str, Union[str, Iterable[str]]] = {
             "include_all_networks": "true" if include_all_networks else "false"
         }
         if third_party_instance_id:
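The `args: dict` to `args: QueryParams` change encodes the fact that query-string arguments are not arbitrary JSON: each key maps to one string or to several. A sketch of what such a mapping serialises to on the wire, using a local alias written in the spirit of `synapse.http.types.QueryParams` (the exact definition there is an assumption):

from typing import Iterable, Mapping, Union
from urllib.parse import urlencode

QueryParams = Mapping[str, Union[str, Iterable[str]]]


def encode_query(args: QueryParams) -> str:
    # Flatten multi-valued keys into repeated key=value pairs.
    pairs = []
    for key, value in args.items():
        if isinstance(value, str):
            pairs.append((key, value))
        else:
            pairs.extend((key, v) for v in value)
    return urlencode(pairs)


assert (
    encode_query({"include_all_networks": "true", "field": ["a", "b"]})
    == "include_all_networks=true&field=a&field=b"
)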
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
-from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple

 from synapse.replication.http.account_data import (
     ReplicationAddTagRestServlet,

@@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID
 if TYPE_CHECKING:
     from synapse.server import HomeServer

+logger = logging.getLogger(__name__)
+
+ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
+    [str, Optional[str], str, JsonDict], Awaitable
+]
+

 class AccountDataHandler:
     def __init__(self, hs: "HomeServer"):

@@ -40,6 +47,44 @@ class AccountDataHandler:
         self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
         self._account_data_writers = hs.config.worker.writers.account_data

+        self._on_account_data_updated_callbacks: List[
+            ON_ACCOUNT_DATA_UPDATED_CALLBACK
+        ] = []
+
+    def register_module_callbacks(
+        self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
+    ) -> None:
+        """Register callbacks from modules."""
+        if on_account_data_updated is not None:
+            self._on_account_data_updated_callbacks.append(on_account_data_updated)
+
+    async def _notify_modules(
+        self,
+        user_id: str,
+        room_id: Optional[str],
+        account_data_type: str,
+        content: JsonDict,
+    ) -> None:
+        """Notifies modules about new account data changes.
+
+        A change can be either a new account data type being added, or the content
+        associated with a type being changed. Account data for a given type is removed by
+        changing the associated content to an empty dictionary.
+
+        Note that this is not called when the tags associated with a room change.
+
+        Args:
+            user_id: The user whose account data is changing.
+            room_id: The ID of the room the account data change concerns, if any.
+            account_data_type: The type of the account data.
+            content: The content that is now associated with this type.
+        """
+        for callback in self._on_account_data_updated_callbacks:
+            try:
+                await callback(user_id, room_id, account_data_type, content)
+            except Exception as e:
+                logger.exception("Failed to run module callback %s: %s", callback, e)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:

@@ -63,6 +108,8 @@ class AccountDataHandler:
                 "account_data_key", max_stream_id, users=[user_id]
             )

+            await self._notify_modules(user_id, room_id, account_data_type, content)
+
             return max_stream_id
         else:
             response = await self._room_data_client(

@@ -96,6 +143,9 @@ class AccountDataHandler:
             self._notifier.on_new_event(
                 "account_data_key", max_stream_id, users=[user_id]
             )
+
+            await self._notify_modules(user_id, None, account_data_type, content)
+
             return max_stream_id
         else:
             response = await self._user_data_client(
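The new `ON_ACCOUNT_DATA_UPDATED_CALLBACK` receives `(user_id, room_id, account_data_type, content)`, with `room_id=None` for global account data and empty `content` signalling removal. A hedged sketch of a module-side callback with that signature; the `register_module_callbacks` wiring is shown as a comment because it needs a live handler:

import asyncio
from typing import Any, Dict, Optional

JsonDict = Dict[str, Any]  # stand-in for synapse.types.JsonDict


async def audit_account_data(
    user_id: str,
    room_id: Optional[str],
    account_data_type: str,
    content: JsonDict,
) -> None:
    # Per the docstring above, an empty dict means the type was removed.
    action = "removed" if content == {} else "updated"
    scope = room_id or "global"
    print(f"{user_id} {action} {account_data_type} ({scope})")


# With a live AccountDataHandler this would be wired up as:
#   handler.register_module_callbacks(on_account_data_updated=audit_account_data)
asyncio.run(audit_account_data("@alice:example.com", None, "m.push_rules", {}))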
@@ -180,9 +180,9 @@ class AccountValidityHandler:
         expiring_users = await self.store.get_users_expiring_soon()

         if expiring_users:
-            for user in expiring_users:
+            for user_id, expiration_ts_ms in expiring_users:
                 await self._send_renewal_email(
-                    user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
+                    user_id=user_id, expiration_ts=expiration_ts_ms
                 )

     async def send_renewal_email_to_user(self, user_id: str) -> None:
@@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
+from synapse.types import (
+    DeviceListUpdates,
+    JsonDict,
+    RoomAlias,
+    RoomStreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import Linearizer
 from synapse.util.metrics import Measure

@@ -58,6 +64,9 @@ class ApplicationServicesHandler:
         self._msc2409_to_device_messages_enabled = (
             hs.config.experimental.msc2409_to_device_messages_enabled
         )
+        self._msc3202_transaction_extensions_enabled = (
+            hs.config.experimental.msc3202_transaction_extensions
+        )

         self.current_max = 0
         self.is_processing = False

@@ -204,9 +213,9 @@ class ApplicationServicesHandler:
         Args:
             stream_key: The stream the event came from.

-            `stream_key` can be "typing_key", "receipt_key", "presence_key" or
-            "to_device_key". Any other value for `stream_key` will cause this function
-            to return early.
+            `stream_key` can be "typing_key", "receipt_key", "presence_key",
+            "to_device_key" or "device_list_key". Any other value for `stream_key`
+            will cause this function to return early.

             Ephemeral events will only be pushed to appservices that have opted into
             receiving them by setting `push_ephemeral` to true in their registration

@@ -230,6 +239,7 @@ class ApplicationServicesHandler:
             "receipt_key",
             "presence_key",
             "to_device_key",
+            "device_list_key",
         ):
             return

@@ -253,15 +263,37 @@ class ApplicationServicesHandler:
         ):
             return

+        # Ignore device lists if the feature flag is not enabled
+        if (
+            stream_key == "device_list_key"
+            and not self._msc3202_transaction_extensions_enabled
+        ):
+            return
+
         # Check whether there are any appservices which have registered to receive
         # ephemeral events.
         #
         # Note that whether these events are actually relevant to these appservices
         # is decided later on.
+        services = self.store.get_app_services()
         services = [
             service
-            for service in self.store.get_app_services()
-            if service.supports_ephemeral
+            for service in services
+            # Different stream keys require different support booleans
+            if (
+                stream_key
+                in (
+                    "typing_key",
+                    "receipt_key",
+                    "presence_key",
+                    "to_device_key",
+                )
+                and service.supports_ephemeral
+            )
+            or (
+                stream_key == "device_list_key"
+                and service.msc3202_transaction_extensions
+            )
         ]
         if not services:
             # Bail out early if none of the target appservices have explicitly registered
@ -298,10 +330,8 @@ class ApplicationServicesHandler:
|
|||
continue
|
||||
|
||||
# Since we read/update the stream position for this AS/stream
|
||||
with (
|
||||
await self._ephemeral_events_linearizer.queue(
|
||||
(service.id, stream_key)
|
||||
)
|
||||
async with self._ephemeral_events_linearizer.queue(
|
||||
(service.id, stream_key)
|
||||
):
|
||||
if stream_key == "receipt_key":
|
||||
events = await self._handle_receipts(service, new_token)
|
||||
|
|
@ -336,6 +366,20 @@ class ApplicationServicesHandler:
|
|||
service, "to_device", new_token
|
||||
)
|
||||
|
||||
elif stream_key == "device_list_key":
|
||||
device_list_summary = await self._get_device_list_summary(
|
||||
service, new_token
|
||||
)
|
||||
if device_list_summary:
|
||||
self.scheduler.enqueue_for_appservice(
|
||||
service, device_list_summary=device_list_summary
|
||||
)
|
||||
|
||||
# Persist the latest handled stream token for this appservice
|
||||
await self.store.set_appservice_stream_type_pos(
|
||||
service, "device_list", new_token
|
||||
)
|
||||
|
||||
async def _handle_typing(
|
||||
self, service: ApplicationService, new_token: int
|
||||
) -> List[JsonDict]:
|
||||
|
|
@ -542,6 +586,96 @@ class ApplicationServicesHandler:
|
|||
|
||||
return message_payload
|
||||
|
||||
async def _get_device_list_summary(
|
||||
self,
|
||||
appservice: ApplicationService,
|
||||
new_key: int,
|
||||
) -> DeviceListUpdates:
|
||||
"""
|
||||
Retrieve a list of users who have changed their device lists.
|
||||
|
||||
Args:
|
||||
appservice: The application service to retrieve device list changes for.
|
||||
new_key: The stream key of the device list change that triggered this method call.
|
||||
|
||||
Returns:
|
||||
A set of device list updates, comprised of users that the appservices needs to:
|
||||
* resync the device list of, and
|
||||
* stop tracking the device list of.
|
||||
"""
|
||||
# Fetch the last successfully processed device list update stream ID
|
||||
# for this appservice.
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
appservice, "device_list"
|
||||
)
|
||||
|
||||
# Fetch the users who have modified their device list since then.
|
||||
users_with_changed_device_lists = (
|
||||
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
|
||||
)
|
||||
|
||||
# Filter out any users the application service is not interested in
|
||||
#
|
||||
# For each user who changed their device list, we want to check whether this
|
||||
# appservice would be interested in the change.
|
||||
filtered_users_with_changed_device_lists = {
|
||||
user_id
|
||||
for user_id in users_with_changed_device_lists
|
||||
if await self._is_appservice_interested_in_device_lists_of_user(
|
||||
appservice, user_id
|
||||
)
|
||||
}
|
||||
|
||||
# Create a summary of "changed" and "left" users.
|
||||
# TODO: Calculate "left" users.
|
||||
device_list_summary = DeviceListUpdates(
|
||||
changed=filtered_users_with_changed_device_lists
|
||||
)
|
||||
|
||||
return device_list_summary
|
||||
|
||||
async def _is_appservice_interested_in_device_lists_of_user(
|
||||
self,
|
||||
appservice: ApplicationService,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether a given application service is interested in the device list
|
||||
updates of a given user.
|
||||
|
||||
The application service is interested in the user's device list updates if any
|
||||
of the following are true:
|
||||
* The user is the appservice's sender localpart user.
|
||||
* The user is in the appservice's user namespace.
|
||||
* At least one member of one room that the user is a part of is in the
|
||||
appservice's user namespace.
|
||||
* The appservice is explicitly (via room ID or alias) interested in at
|
||||
least one room that the user is in.
|
||||
|
||||
Args:
|
||||
appservice: The application service to gauge interest of.
|
||||
user_id: The ID of the user whose device list interest is in question.
|
||||
|
||||
Returns:
|
||||
True if the application service is interested in the user's device lists, False
|
||||
otherwise.
|
||||
"""
|
||||
# This method checks against both the sender localpart user as well as if the
|
||||
# user is in the appservice's user namespace.
|
||||
if appservice.is_interested_in_user(user_id):
|
||||
return True
|
||||
|
||||
# Determine whether any of the rooms the user is in justifies sending this
|
||||
# device list update to the application service.
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
for room_id in room_ids:
|
||||
# This method covers checking room members for appservice interest as well as
|
||||
# room ID and alias checks.
|
||||
if await appservice.is_interested_in_room(room_id, self.store):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def query_user_exists(self, user_id: str) -> bool:
|
||||
"""Check if any application service knows this user_id exists.
|
||||
|
||||
|
|
|
|||
|
|
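Taken together, `_get_device_list_summary` produces a `DeviceListUpdates` that the transaction scheduler can serialise into the MSC3202 device-list field of an appservice transaction. A rough sketch of that shape — the field layout below is illustrative of the MSC3202 `changed`/`left` semantics, not the exact scheduler serialisation code:

    from synapse.types import DeviceListUpdates

    summary = DeviceListUpdates(changed={"@alice:example.org"}, left=set())

    # DeviceListUpdates is truthy only when there is something to send,
    # which is why the handler checks `if device_list_summary:` above.
    if summary:
        txn_field = {
            "changed": sorted(summary.changed),
            "left": sorted(summary.left),
        }
        # -> {"changed": ["@alice:example.org"], "left": []}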
@@ -211,6 +211,7 @@ class AuthHandler:
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
+        self._third_party_rules = hs.get_third_party_event_rules()
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.

@@ -1505,6 +1506,8 @@ class AuthHandler:
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
+        await self._third_party_rules.on_threepid_bind(user_id, medium, address)
+
     async def delete_threepid(
         self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
     ) -> bool:
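`on_threepid_bind` is the new third-party-rules callback fired after a third-party identifier is bound locally to an account. A minimal module sketch, assuming registration through `ModuleApi.register_third_party_rules_callbacks` (the class name and logging body are illustrative only):

    import logging

    from synapse.module_api import ModuleApi

    logger = logging.getLogger(__name__)


    class ThreepidAuditModule:
        def __init__(self, config: dict, api: ModuleApi):
            api.register_third_party_rules_callbacks(
                on_threepid_bind=self.on_threepid_bind,
            )

        async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
            # e.g. medium="email", address="alice@example.org"
            logger.info("3PID %s:%s bound to %s", medium, address, user_id)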
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,

@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.

@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        set_tag("target_hosts", hosts)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
+
+            set_tag("target_hosts", hosts)
+
+            hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:

@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(

@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]

@@ -725,7 +833,7 @@ class DeviceListUpdater:
     async def _handle_device_updates(self, user_id: str) -> None:
         "Actually handle pending updates."
 
-        with (await self._remote_edu_linearizer.queue(user_id)):
+        async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
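The `is_processing`/`new_data` flag pair in `_handle_new_device_update_async` is a common re-kick pattern: a single background loop drains work, and callers that arrive mid-run just set a flag instead of spawning a second loop. Stripped of the Synapse specifics, the pattern looks like this generic sketch (all names are illustrative):

    import asyncio


    class Drainer:
        def __init__(self) -> None:
            self._is_processing = False
            self._new_data = False
            self._queue: list[str] = []

        def notify(self, item: str) -> None:
            self._queue.append(item)
            # Fire-and-forget; at most one drain loop ever runs.
            asyncio.ensure_future(self._drain())

        async def _drain(self) -> None:
            if self._is_processing:
                # A loop is already running; tell it to go around again.
                self._new_data = True
                return
            self._is_processing = True
            try:
                while True:
                    self._new_data = False
                    batch, self._queue = self._queue, []
                    if not batch:
                        # Nothing left, unless new work arrived during the check.
                        if self._new_data:
                            continue
                        return
                    for item in batch:
                        await asyncio.sleep(0)  # stand-in for real processing
            finally:
                self._is_processing = False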
@@ -118,7 +118,7 @@ class E2eKeysHandler:
             from_device_id: the device making the query. This is used to limit
                 the number of in-flight queries at a time.
         """
-        with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+        async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
             device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                 "device_keys", {}
             )

@@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater:
         device_handler = self.e2e_keys_handler.device_handler
         device_list_updater = device_handler.device_list_updater
 
-        with (await self._remote_edu_linearizer.queue(user_id)):
+        async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
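Every `with (await linearizer.queue(key)):` call site in this release moves to `async with linearizer.queue(key):`. The old form awaited a context manager and then entered it synchronously; the new form makes `queue` usable directly as an async context manager, which keeps lock acquisition inside the `async with` and behaves better under cancellation. A self-contained sketch of the idea — deliberately simplified, the real `Linearizer` in `synapse.util.async_helpers` is keyed the same way but fairer and more featureful:

    import asyncio
    from collections import defaultdict
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, DefaultDict, Hashable


    class TinyLinearizer:
        """One asyncio.Lock per key; callers for the same key run in series."""

        def __init__(self) -> None:
            self._locks: DefaultDict[Hashable, asyncio.Lock] = defaultdict(asyncio.Lock)

        @asynccontextmanager
        async def queue(self, key: Hashable) -> AsyncIterator[None]:
            lock = self._locks[key]
            await lock.acquire()  # acquisition happens inside the context manager
            try:
                yield
            finally:
                lock.release()


    async def main() -> None:
        linearizer = TinyLinearizer()
        async with linearizer.queue(("room1", "receipt_key")):
            ...  # critical section, serialised per key

    asyncio.run(main())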
@@ -83,7 +83,7 @@ class E2eRoomKeysHandler:
 
         # we deliberately take the lock to get keys so that changing the version
         # works atomically
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             # make sure the backup version exists
             try:
                 await self.store.get_e2e_room_keys_version_info(user_id, version)

@@ -126,7 +126,7 @@ class E2eRoomKeysHandler:
         """
 
         # lock for consistency with uploading
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             # make sure the backup version exists
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(

@@ -187,7 +187,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # XXX: perhaps we should use a finer grained lock here?
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
 
             # Check that the version we're trying to upload is the current version
             try:

@@ -332,7 +332,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # lock everyone out until we've switched version
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             new_version = await self.store.create_e2e_room_keys_version(
                 user_id, version_info
             )

@@ -359,7 +359,7 @@ class E2eRoomKeysHandler:
         }
         """
 
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 res = await self.store.get_e2e_room_keys_version_info(user_id, version)
             except StoreError as e:

@@ -383,7 +383,7 @@ class E2eRoomKeysHandler:
             NotFoundError: if this backup version doesn't exist
         """
 
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 await self.store.delete_e2e_room_keys_version(user_id, version)
             except StoreError as e:

@@ -413,7 +413,7 @@ class E2eRoomKeysHandler:
             raise SynapseError(
                 400, "Version in body does not match", Codes.INVALID_PARAM
             )
-        with (await self._upload_linearizer.queue(user_id)):
+        async with self._upload_linearizer.queue(user_id):
             try:
                 old_info = await self.store.get_e2e_room_keys_version_info(
                     user_id, version
@@ -151,7 +151,7 @@ class FederationHandler:
             return. This is used as part of the heuristic to decide if we
             should back paginate.
         """
-        with (await self._room_backfill.queue(room_id)):
+        async with self._room_backfill.queue(room_id):
             return await self._maybe_backfill_inner(room_id, current_depth, limit)
 
     async def _maybe_backfill_inner(
@@ -224,7 +224,7 @@ class FederationEventHandler:
                 len(missing_prevs),
                 shortstr(missing_prevs),
             )
-            with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+            async with self._room_pdu_linearizer.queue(pdu.room_id):
                 logger.info(
                     "Acquired room lock to fetch %d missing prev_events",
                     len(missing_prevs),

@@ -469,6 +469,12 @@ class FederationEventHandler:
         if context.rejected:
             raise SynapseError(400, "Join event was rejected")
 
+        # the remote server is responsible for sending our join event to the rest
+        # of the federation. Indeed, attempting to do so will result in problems
+        # when we try to look up the state before the join (to get the server list)
+        # and discover that we do not have it.
+        event.internal_metadata.proactively_send = False
+
         return await self.persist_events_and_notify(room_id, [(event, context)])
 
     async def backfill(

@@ -891,10 +897,24 @@ class FederationEventHandler:
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
         missing_events = missing_desired_events | missing_auth_events
-        logger.debug("Fetching %i events from remote", len(missing_events))
-        await self._get_events_and_persist(
-            destination=destination, room_id=room_id, event_ids=missing_events
-        )
+
+        # Making an individual request for each of 1000s of events has a lot of
+        # overhead. On the other hand, we don't really want to fetch all of the events
+        # if we already have most of them.
+        #
+        # As an arbitrary heuristic, if we are missing more than 10% of the events, then
+        # we fetch the whole state.
+        #
+        # TODO: might it be better to have an API which lets us do an aggregate event
+        # request
+        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+            logger.debug("Requesting complete state from remote")
+            await self._get_state_and_persist(destination, room_id, event_id)
+        else:
+            logger.debug("Fetching %i events from remote", len(missing_events))
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=missing_events
+            )
 
         # we need to make sure we re-load from the database to get the rejected
         # state correct.

@@ -953,6 +973,27 @@ class FederationEventHandler:
 
         return remote_state
 
+    async def _get_state_and_persist(
+        self, destination: str, room_id: str, event_id: str
+    ) -> None:
+        """Get the complete room state at a given event, and persist any new events
+        as outliers"""
+        room_version = await self._store.get_room_version(room_id)
+        auth_events, state_events = await self._federation_client.get_room_state(
+            destination, room_id, event_id=event_id, room_version=room_version
+        )
+        logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state_events)
+        )
+
+        # we also need the event itself.
+        if not await self._store.have_seen_event(room_id, event_id):
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=(event_id,)
+            )
+
     async def _process_received_pdu(
         self,
         origin: str,
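The guard `(len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids)` is the integer-arithmetic form of "missing roughly 10% or more of the events": multiplying by ten avoids floating point. With illustrative numbers:

    # 120 missing out of 1000 state+auth events: 120 * 10 = 1200 >= 1000,
    # so fetch the whole room state in one /state request.
    missing, total = 120, 1000
    assert (missing * 10) >= total

    # 50 missing out of 1000: 500 < 1000, so fetch the 50 events individually.
    missing = 50
    assert not (missing * 10) >= total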
@@ -858,8 +858,6 @@ class IdentityHandler:
 
         if room_type is not None:
             invite_config["room_type"] = room_type
-            # TODO The unstable field is deprecated and should be removed in the future.
-            invite_config["org.matrix.msc3288.room_type"] = room_type
 
         # If a custom web client location is available, include it in the request.
         if self._web_client_location:
@@ -853,7 +853,7 @@ class EventCreationHandler:
         # a situation where event persistence can't keep up, causing
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
-        with (await self.limiter.queue(event_dict["room_id"])):
+        async with self.limiter.queue(event_dict["room_id"]):
             if txn_id and requester.access_token_id:
                 existing_event_id = await self.store.get_event_id_from_transaction_id(
                     event_dict["room_id"],
@@ -441,7 +441,14 @@ class PaginationHandler:
         if pagin_config.from_token:
             from_token = pagin_config.from_token
         else:
-            from_token = self.hs.get_event_sources().get_current_token_for_pagination()
+            from_token = (
+                await self.hs.get_event_sources().get_current_token_for_pagination(
+                    room_id
+                )
+            )
+            # We expect `/messages` to use historic pagination tokens by default but
+            # `/messages` should still works with live tokens when manually provided.
+            assert from_token.room_key.topological
 
         if pagin_config.limit is None:
             # This shouldn't happen as we've set a default limit before this
@@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler):
             is_syncing: Whether or not the user is now syncing
             sync_time_msec: Time in ms when the user was last syncing
         """
-        with (await self.external_sync_linearizer.queue(process_id)):
+        async with self.external_sync_linearizer.queue(process_id):
             prev_state = await self.current_state_for_user(user_id)
 
             process_presence = self.external_process_to_current_syncs.setdefault(

@@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler):
 
         Used when the process has stopped/disappeared.
         """
-        with (await self.external_sync_linearizer.queue(process_id)):
+        async with self.external_sync_linearizer.queue(process_id):
             process_presence = self.external_process_to_current_syncs.pop(
                 process_id, set()
             )

@@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
             # We'll actually pull the presence updates for these users at the end.
             interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
 
-            if from_key:
+            if from_key is not None:
                 # First get all users that have had a presence update
                 updated_users = stream_change_cache.get_all_entities_changed(from_key)
 
@@ -41,7 +41,7 @@ class ReadMarkerHandler:
             the read marker has changed.
         """
 
-        with await self.read_marker_linearizer.queue((room_id, user_id)):
+        async with self.read_marker_linearizer.queue((room_id, user_id)):
            existing_read_marker = await self.store.get_account_data_for_room_and_type(
                user_id, room_id, "m.fully_read"
            )
@@ -12,7 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
 
 import attr
 from frozendict import frozendict

@@ -20,12 +29,12 @@ from frozendict import frozendict
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
-from synapse.types import JsonDict, Requester, StreamToken
+from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 
 logger = logging.getLogger(__name__)

@@ -116,7 +125,10 @@ class RelationsHandler:
         if event is None:
             raise SynapseError(404, "Unknown parent event.")
 
-        pagination_chunk = await self._main_store.get_relations_for_event(
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        related_events, next_token = await self._main_store.get_relations_for_event(
             event_id=event_id,
             event=event,
             room_id=room_id,

@@ -130,7 +142,7 @@ class RelationsHandler:
         )
 
         events = await self._main_store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
+            [e.event_id for e in related_events]
         )
 
         events = await filter_events_for_client(

@@ -152,14 +164,100 @@ class RelationsHandler:
             events, now, bundle_aggregations=aggregations
         )
 
-        return_value = await pagination_chunk.to_dict(self._main_store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
+        return_value = {
+            "chunk": serialized_events,
+            "original_event": original_event,
+        }
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
 
         return return_value
 
+    async def get_relations_for_event(
+        self,
+        event_id: str,
+        event: EventBase,
+        room_id: str,
+        relation_type: str,
+        ignored_users: FrozenSet[str] = frozenset(),
+    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+        """Get a list of events which relate to an event, ordered by topological ordering.
+
+        Args:
+            event_id: Fetch events that relate to this event ID.
+            event: The matching EventBase to event_id.
+            room_id: The room the event belongs to.
+            relation_type: The type of relation.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
+        """
+
+        # Call the underlying storage method, which is cached.
+        related_events, next_token = await self._main_store.get_relations_for_event(
+            event_id, event, room_id, relation_type, direction="f"
+        )
+
+        # Filter out ignored users and convert to the expected format.
+        related_events = [
+            event for event in related_events if event.sender not in ignored_users
+        ]
+
+        return related_events, next_token
+
+    async def get_annotations_for_event(
+        self,
+        event_id: str,
+        room_id: str,
+        limit: int = 5,
+        ignored_users: FrozenSet[str] = frozenset(),
+    ) -> List[JsonDict]:
+        """Get a list of annotations on the event, grouped by event type and
+        aggregation key, sorted by count.
+
+        This is used e.g. to get the what and how many reactions have happend
+        on an event.
+
+        Args:
+            event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
+            limit: Only fetch the `limit` groups.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
+        """
+        # Get the base results for all users.
+        full_results = await self._main_store.get_aggregation_groups_for_event(
+            event_id, room_id, limit
+        )
+
+        # Then subtract off the results for any ignored users.
+        ignored_results = await self._main_store.get_aggregation_groups_for_users(
+            event_id, room_id, limit, ignored_users
+        )
+
+        filtered_results = []
+        for result in full_results:
+            key = (result["type"], result["key"])
+            if key in ignored_results:
+                result = result.copy()
+                result["count"] -= ignored_results[key]
+                if result["count"] <= 0:
+                    continue
+            filtered_results.append(result)
+
+        return filtered_results
+
     async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
+        self, event: EventBase, ignored_users: FrozenSet[str]
     ) -> Optional[BundledAggregations]:
         """Generate bundled aggregations for an event.
 
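The subtraction in `get_annotations_for_event` is easiest to see with concrete data: the per-event aggregation groups are computed once for everyone (so the cache stays requester-independent), then the counts contributed by the requester's ignored users are removed. With illustrative values:

    # Aggregation groups for all users (cached, requester-independent).
    full_results = [
        {"type": "m.reaction", "key": "👍", "count": 5},
        {"type": "m.reaction", "key": "🎉", "count": 1},
    ]

    # Counts contributed by users this requester ignores.
    ignored_results = {("m.reaction", "👍"): 2, ("m.reaction", "🎉"): 1}

    filtered = []
    for result in full_results:
        key = (result["type"], result["key"])
        if key in ignored_results:
            result = dict(result, count=result["count"] - ignored_results[key])
            if result["count"] <= 0:
                continue  # the 🎉 group disappears entirely
        filtered.append(result)

    assert filtered == [{"type": "m.reaction", "key": "👍", "count": 3}]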
@@ -167,7 +265,7 @@ class RelationsHandler:
 
         Args:
             event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
+            ignored_users: The users ignored by the requesting user.
 
         Returns:
             The bundled aggregations for an event, if bundled aggregations are

@@ -190,23 +288,125 @@ class RelationsHandler:
         # while others need more processing during serialization.
         aggregations = BundledAggregations()
 
-        annotations = await self._main_store.get_aggregation_groups_for_event(
-            event_id, room_id
+        annotations = await self.get_annotations_for_event(
+            event_id, room_id, ignored_users=ignored_users
         )
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
+        if annotations:
+            aggregations.annotations = {"chunk": annotations}
 
-        references = await self._main_store.get_relations_for_event(
-            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+        references, next_token = await self.get_relations_for_event(
+            event_id,
+            event,
+            room_id,
+            RelationTypes.REFERENCE,
+            ignored_users=ignored_users,
         )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
+        if references:
+            aggregations.references = {
+                "chunk": [{"event_id": event.event_id} for event in references]
+            }
+
+            if next_token:
+                aggregations.references["next_batch"] = await next_token.to_string(
+                    self._main_store
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
 
+    async def get_threads_for_events(
+        self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+    ) -> Dict[str, _ThreadAggregation]:
+        """Get the bundled aggregations for threads for the requested events.
+
+        Args:
+            event_ids: Events to get aggregations for threads.
+            user_id: The user requesting the bundled aggregations.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            A dictionary mapping event ID to the thread information.
+
+            May not contain a value for all requested event IDs.
+        """
+        user = UserID.from_string(user_id)
+
+        # Fetch thread summaries.
+        summaries = await self._main_store.get_thread_summaries(event_ids)
+
+        # Only fetch participated for a limited selection based on what had
+        # summaries.
+        thread_event_ids = [
+            event_id for event_id, summary in summaries.items() if summary
+        ]
+        participated = await self._main_store.get_threads_participated(
+            thread_event_ids, user_id
+        )
+
+        # Then subtract off the results for any ignored users.
+        ignored_results = await self._main_store.get_threaded_messages_per_user(
+            thread_event_ids, ignored_users
+        )
+
+        # A map of event ID to the thread aggregation.
+        results = {}
+
+        for event_id, summary in summaries.items():
+            if summary:
+                thread_count, latest_thread_event, edit = summary
+
+                # Subtract off the count of any ignored users.
+                for ignored_user in ignored_users:
+                    thread_count -= ignored_results.get((event_id, ignored_user), 0)
+
+                # This is gnarly, but if the latest event is from an ignored user,
+                # attempt to find one that isn't from an ignored user.
+                if latest_thread_event.sender in ignored_users:
+                    room_id = latest_thread_event.room_id
+
+                    # If the root event is not found, something went wrong, do
+                    # not include a summary of the thread.
+                    event = await self._event_handler.get_event(user, room_id, event_id)
+                    if event is None:
+                        continue
+
+                    potential_events, _ = await self.get_relations_for_event(
+                        event_id,
+                        event,
+                        room_id,
+                        RelationTypes.THREAD,
+                        ignored_users,
+                    )
+
+                    # If all found events are from ignored users, do not include
+                    # a summary of the thread.
+                    if not potential_events:
+                        continue
+
+                    # The *last* event returned is the one that is cared about.
+                    event = await self._event_handler.get_event(
+                        user, room_id, potential_events[-1].event_id
+                    )
+                    # It is unexpected that the event will not exist.
+                    if event is None:
+                        logger.warning(
+                            "Unable to fetch latest event in a thread with event ID: %s",
+                            potential_events[-1].event_id,
+                        )
+                        continue
+                    latest_thread_event = event
+
+                results[event_id] = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    latest_edit=edit,
+                    count=thread_count,
+                    # If there's a thread summary it must also exist in the
+                    # participated dictionary.
+                    current_user_participated=participated[event_id],
+                )
+
+        return results
+
     async def get_bundled_aggregations(
         self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:

@@ -230,13 +430,21 @@ class RelationsHandler:
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
 
+        # Fetch any ignored users of the requesting user.
+        ignored_users = await self._main_store.ignored_users(user_id)
+
         # Fetch other relations per event.
         for event in events_by_id.values():
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+            event_result = await self._get_bundled_aggregation_for_event(
+                event, ignored_users
+            )
             if event_result:
                 results[event.event_id] = event_result
 
         # Fetch any edits (but not for redacted events).
+        #
+        # Note that there is no use in limiting edits by ignored users since the
+        # parent event should be ignored in the first place if the user is ignored.
         edits = await self._main_store.get_applicable_edits(
             [
                 event_id

@@ -247,25 +455,10 @@ class RelationsHandler:
         for event_id, edit in edits.items():
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
-        # Fetch thread summaries.
-        summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
-        participated = await self._main_store.get_threads_participated(
-            [event_id for event_id, summary in summaries.items() if summary], user_id
+        threads = await self.get_threads_for_events(
+            events_by_id.keys(), user_id, ignored_users
         )
-        for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event, edit = summary
-                results.setdefault(
-                    event_id, BundledAggregations()
-                ).thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    latest_edit=edit,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=participated[event_id],
-                )
+        for event_id, thread in threads.items():
+            results.setdefault(event_id, BundledAggregations()).thread = thread
 
         return results
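After this refactor, `get_bundled_aggregations` remains the single entry point used by sync, `/relations` and `/context`, with ignored users now resolved once per request and threaded through every helper. A sketch of calling it, assuming `hs` is a running `HomeServer` and `events` a list of `EventBase` (the function wrapper is illustrative):

    async def bundle_for_response(hs, events, requesting_user_id: str):
        handler = hs.get_relations_handler()
        # Maps event_id -> BundledAggregations (annotations, references,
        # replace and thread fields), already filtered for ignored users.
        return await handler.get_bundled_aggregations(events, requesting_user_id)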
@@ -771,7 +771,9 @@ class RoomCreationHandler:
                 % (user_id,),
             )
 
-        visibility = config.get("visibility", None)
+        # The spec says rooms should default to private visibility if
+        # `visibility` is not specified.
+        visibility = config.get("visibility", "private")
         is_public = visibility == "public"
 
         if "room_id" in config:

@@ -891,7 +893,7 @@ class RoomCreationHandler:
         #
         # we also don't need to check the requester's shadow-ban here, as we
         # have already done so above (and potentially emptied invite_list).
-        with (await self.room_member_handler.member_linearizer.queue((room_id,))):
+        async with self.room_member_handler.member_linearizer.queue((room_id,)):
             content = {}
             is_direct = config.get("is_direct", None)
             if is_direct:

@@ -1452,8 +1454,8 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def get_current_key(self) -> RoomStreamToken:
         return self.store.get_room_max_token()
 
-    def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
-        return self.store.get_room_events_max_id(room_id)
+    def get_current_key_for_room(self, room_id: str) -> Awaitable[RoomStreamToken]:
+        return self.store.get_current_room_stream_token_for_room_id(room_id)
 
 
 class ShutdownRoomResponse(TypedDict):
@@ -158,8 +158,8 @@ class RoomBatchHandler:
     ) -> List[str]:
         """Takes all `state_events_at_start` event dictionaries and creates/persists
         them in a floating state event chain which don't resolve into the current room
-        state. They are floating because they reference no prev_events and are marked
-        as outliers which disconnects them from the normal DAG.
+        state. They are floating because they reference no prev_events which disconnects
+        them from the normal DAG.
 
         Args:
             state_events_at_start:

@@ -215,31 +215,23 @@ class RoomBatchHandler:
                     room_id=room_id,
                     action=membership,
                     content=event_dict["content"],
-                    # Mark as an outlier to disconnect it from the normal DAG
-                    # and not show up between batches of history.
-                    outlier=True,
                     historical=True,
                     # Only the first event in the state chain should be floating.
                     # The rest should hang off each other in a chain.
                     allow_no_prev_events=index == 0,
                     prev_event_ids=prev_event_ids_for_state_chain,
-                    # Since each state event is marked as an outlier, the
-                    # `EventContext.for_outlier()` won't have any `state_ids`
-                    # set and therefore can't derive any state even though the
-                    # prev_events are set. Also since the first event in the
-                    # state chain is floating with no `prev_events`, it can't
-                    # derive state from anywhere automatically. So we need to
-                    # set some state explicitly.
+                    # The first event in the state chain is floating with no
+                    # `prev_events` which means it can't derive state from
+                    # anywhere automatically. So we need to set some state
+                    # explicitly.
                     #
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
-                    # reference and also update in the event when we append later.
+                    # reference and also update in the event when we append
+                    # later.
                     state_event_ids=state_event_ids.copy(),
                 )
             else:
                 # TODO: Add some complement tests that adds state that is not member joins
                 # and will use this code path. Maybe we only want to support join state events
                 # and can get rid of this `else`?
                 (
                     event,
                     _,

@@ -248,21 +240,15 @@ class RoomBatchHandler:
                         state_event["sender"], app_service_requester.app_service
                     ),
                     event_dict,
-                    # Mark as an outlier to disconnect it from the normal DAG
-                    # and not show up between batches of history.
-                    outlier=True,
                     historical=True,
                     # Only the first event in the state chain should be floating.
                     # The rest should hang off each other in a chain.
                     allow_no_prev_events=index == 0,
                     prev_event_ids=prev_event_ids_for_state_chain,
-                    # Since each state event is marked as an outlier, the
-                    # `EventContext.for_outlier()` won't have any `state_ids`
-                    # set and therefore can't derive any state even though the
-                    # prev_events are set. Also since the first event in the
-                    # state chain is floating with no `prev_events`, it can't
-                    # derive state from anywhere automatically. So we need to
-                    # set some state explicitly.
+                    # The first event in the state chain is floating with no
+                    # `prev_events` which means it can't derive state from
+                    # anywhere automatically. So we need to set some state
+                    # explicitly.
                     #
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
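The `allow_no_prev_events=index == 0` idiom builds the "floating" chain described in the docstring: only the first historical state event has no `prev_events`; every later one points at the event created just before it. Schematically — a toy illustration with made-up event IDs, not the handler code:

    state_events = ["m.room.create", "m.room.member/@a", "m.room.member/@b"]

    prev_event_ids_for_state_chain: list[str] = []
    for index, ev_type in enumerate(state_events):
        allow_no_prev_events = index == 0  # only the head of the chain floats
        event_id = f"$event{index}"        # stand-in for the persisted event's ID
        print(ev_type, "prev_events =", prev_event_ids_for_state_chain,
              "floating =", allow_no_prev_events)
        # The next iteration hangs off the event we just created.
        prev_event_ids_for_state_chain = [event_id]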
@@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # We first linearise by the application service (to try to limit concurrent joins
         # by application services), and then by room ID.
-        with (await self.member_as_limiter.queue(as_id)):
-            with (await self.member_linearizer.queue(key)):
+        async with self.member_as_limiter.queue(as_id):
+            async with self.member_linearizer.queue(key):
                 result = await self.update_membership_locked(
                     requester,
                     target,
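Note the acquisition order above: the application-service limiter is always taken before the per-room member linearizer. Keeping one fixed, global lock order is what makes the nesting deadlock-free. A compact self-contained sketch of the same shape (names illustrative):

    import asyncio
    from collections import defaultdict

    as_locks: dict = defaultdict(asyncio.Lock)
    room_locks: dict = defaultdict(asyncio.Lock)


    async def update_membership(as_id: str, key: tuple) -> None:
        # Always the AS limiter first, then the per-room lock: a fixed
        # acquisition order prevents lock-ordering deadlocks.
        async with as_locks[as_id]:
            async with room_locks[key]:
                ...  # stand-in for update_membership_locked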
@@ -59,8 +59,6 @@ class SearchHandler:
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
 
-        self._msc3666_enabled = hs.config.experimental.msc3666_enabled
-
     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 

@@ -353,22 +351,20 @@ class SearchHandler:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-        aggregations = None
-        if self._msc3666_enabled:
-            aggregations = await self._relations_handler.get_bundled_aggregations(
-                # Generate an iterable of EventBase for all the events that will be
-                # returned, including contextual events.
-                itertools.chain(
-                    # The events_before and events_after for each context.
-                    itertools.chain.from_iterable(
-                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
-                        for context in contexts.values()
-                    ),
-                    # The returned events.
-                    search_result.allowed_events,
-                ),
-                user.to_string(),
-            )
+        aggregations = await self._relations_handler.get_bundled_aggregations(
+            # Generate an iterable of EventBase for all the events that will be
+            # returned, including contextual events.
+            itertools.chain(
+                # The events_before and events_after for each context.
+                itertools.chain.from_iterable(
+                    itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                    for context in contexts.values()
+                ),
+                # The returned events.
+                search_result.allowed_events,
+            ),
+            user.to_string(),
+        )
 
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise, the 'age' will be wrong.
@@ -430,7 +430,7 @@ class SsoHandler:
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
-        with await self._mapping_lock.queue(auth_provider_id):
+        async with self._mapping_lock.queue(auth_provider_id):
             # first of all, check if we already have a mapping for this user
             user_id = await self.get_sso_user_by_remote_user_id(
                 auth_provider_id,
@@ -13,17 +13,7 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Collection,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Set,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter

@@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     MutableStateMap,
     Requester,

@@ -184,21 +175,6 @@ class GroupsSyncResult:
         return bool(self.join or self.invite or self.leave)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class DeviceLists:
-    """
-    Attributes:
-        changed: List of user_ids whose devices may have changed
-        left: List of user_ids whose devices we no longer track
-    """
-
-    changed: Collection[str]
-    left: Collection[str]
-
-    def __bool__(self) -> bool:
-        return bool(self.changed or self.left)
-
-
 @attr.s(slots=True, auto_attribs=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined

@@ -240,7 +216,7 @@ class SyncResult:
     knocked: List[KnockedSyncResult]
     archived: List[ArchivedSyncResult]
     to_device: List[JsonDict]
-    device_lists: DeviceLists
+    device_lists: DeviceListUpdates
     device_one_time_keys_count: JsonDict
     device_unused_fallback_key_types: List[str]
     groups: Optional[GroupsSyncResult]

@@ -298,6 +274,8 @@ class SyncHandler:
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
+        self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,

@@ -1176,8 +1154,9 @@ class SyncHandler:
             await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
         )
 
-        logger.debug("Fetching group data")
-        await self._generate_sync_entry_for_groups(sync_result_builder)
+        if self.hs_config.experimental.groups_enabled:
+            logger.debug("Fetching group data")
+            await self._generate_sync_entry_for_groups(sync_result_builder)
 
         num_events = 0
 

@@ -1261,8 +1240,8 @@ class SyncHandler:
         newly_joined_or_invited_or_knocked_users: Set[str],
         newly_left_rooms: Set[str],
         newly_left_users: Set[str],
-    ) -> DeviceLists:
-        """Generate the DeviceLists section of sync
+    ) -> DeviceListUpdates:
+        """Generate the DeviceListUpdates section of sync
 
         Args:
             sync_result_builder

@@ -1380,9 +1359,11 @@ class SyncHandler:
                 if any(e.room_id in joined_rooms for e in entries):
                     newly_left_users.discard(user_id)
 
-            return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
+            return DeviceListUpdates(
+                changed=users_that_have_changed, left=newly_left_users
+            )
         else:
-            return DeviceLists(changed=[], left=[])
+            return DeviceListUpdates()
 
     async def _generate_sync_entry_for_to_device(
         self, sync_result_builder: "SyncResultBuilder"
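In the sync response these two branches become the `device_lists` section: `changed` means "re-query these users' device lists", `left` means "stop tracking them". Roughly, the serialised form looks like this sketch (values illustrative):

    from synapse.types import DeviceListUpdates

    updates = DeviceListUpdates(
        changed={"@alice:example.org"},   # shares a room with us again / new keys
        left={"@bob:example.org"},        # no longer shares any room with us
    )

    sync_body = {
        "device_lists": {
            "changed": sorted(updates.changed),
            "left": sorted(updates.left),
        }
    }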
@@ -1606,13 +1587,15 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users
+                sync_result_builder, ignored_users, self.rooms_to_exclude
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+            room_changes = await self._get_all_rooms(
+                sync_result_builder, ignored_users, self.rooms_to_exclude
+            )
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})

@@ -1688,7 +1671,10 @@ class SyncHandler:
             return False
 
     async def _get_rooms_changed(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 

@@ -1720,7 +1706,7 @@ class SyncHandler:
         # _have_rooms_changed. We could keep the results in memory to avoid a
         # second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key
+            user_id, since_token.room_key, now_token.room_key, excluded_rooms
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}

@@ -1864,6 +1850,7 @@ class SyncHandler:
                         full_state=False,
                         since_token=since_token,
                         upto_token=leave_token,
+                        out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
                     )
                 )
 

@@ -1921,7 +1908,10 @@ class SyncHandler:
             )
 
     async def _get_all_rooms(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 

@@ -1932,7 +1922,7 @@ class SyncHandler:
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
-
+            ignored_rooms: List of rooms to ignore.
         """
 
         user_id = sync_result_builder.sync_config.user.to_string()

@@ -1943,6 +1933,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
+            excluded_rooms=ignored_rooms,
         )
 
         room_entries = []

@@ -2126,33 +2117,41 @@ class SyncHandler:
         ):
             return
 
-        state = await self.compute_state_delta(
-            room_id,
-            batch,
-            sync_config,
-            since_token,
-            now_token,
-            full_state=full_state,
-        )
+        if not room_builder.out_of_band:
+            state = await self.compute_state_delta(
+                room_id,
+                batch,
+                sync_config,
+                since_token,
+                now_token,
+                full_state=full_state,
+            )
+        else:
+            # An out of band room won't have any state changes.
+            state = {}
 
         summary: Optional[JsonDict] = {}
 
         # we include a summary in room responses when we're lazy loading
         # members (as the client otherwise doesn't have enough info to form
         # the name itself).
-        if sync_config.filter_collection.lazy_load_members() and (
-            # we recalculate the summary:
-            #   if there are membership changes in the timeline, or
-            #   if membership has changed during a gappy sync, or
-            #   if this is an initial sync.
-            any(ev.type == EventTypes.Member for ev in batch.events)
-            or (
-                # XXX: this may include false positives in the form of LL
-                # members which have snuck into state
-                batch.limited
-                and any(t == EventTypes.Member for (t, k) in state)
-            )
-            or since_token is None
+        if (
+            not room_builder.out_of_band
+            and sync_config.filter_collection.lazy_load_members()
+            and (
+                # we recalculate the summary:
+                #   if there are membership changes in the timeline, or
+                #   if membership has changed during a gappy sync, or
+                #   if this is an initial sync.
+                any(ev.type == EventTypes.Member for ev in batch.events)
+                or (
+                    # XXX: this may include false positives in the form of LL
+                    # members which have snuck into state
+                    batch.limited
+                    and any(t == EventTypes.Member for (t, k) in state)
+                )
+                or since_token is None
+            )
         ):
             summary = await self.compute_summary(
                 room_id, sync_config, batch, state, now_token

@@ -2396,6 +2395,8 @@ class RoomSyncResultBuilder:
         full_state: Whether the full state should be sent in result
         since_token: Earliest point to return events from, or None
         upto_token: Latest point to return events from.
+        out_of_band: whether the events in the room are "out of band" events
+            and the server isn't in the room.
     """
 
     room_id: str

@@ -2405,3 +2406,5 @@ class RoomSyncResultBuilder:
     full_state: bool
     since_token: Optional[StreamToken]
     upto_token: StreamToken
+
+    out_of_band: bool = False
@@ -107,6 +107,8 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         # TODO: get this from the homeserver rather than creating a new one for
         # each request
         try:
+            assert self._secret is not None
+
             resp_body = await self._http_client.post_urlencoded_get_json(
                 self._url,
                 args={
@ -22,7 +22,6 @@ from typing import (
|
|||
BinaryIO,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
|
|
@ -72,6 +71,7 @@ from twisted.web.iweb import (
|
|||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.types import ISynapseReactor
|
||||
|
|
@@ -97,10 +97,6 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
 # the entries can either be Lists or bytes.
 RawHeaderValue = Sequence[Union[str, bytes]]

-# the type of the query params, to be passed into `urlencode`
-QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
-QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-

 def check_against_blacklist(
     ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
@@ -912,7 +908,7 @@ def read_body_with_max_size(
     return d


-def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+def encode_query_args(args: Optional[QueryParams]) -> bytes:
     """
     Encodes a map of query arguments to bytes which can be appended to a URL.

@@ -925,13 +921,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
     if args is None:
         return b""

-    encoded_args = {}
-    for k, vs in args.items():
-        if isinstance(vs, str):
-            vs = [vs]
-        encoded_args[k] = [v.encode("utf8") for v in vs]
-
-    query_str = urllib.parse.urlencode(encoded_args, True)
+    query_str = urllib.parse.urlencode(args, True)

     return query_str.encode("utf8")

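The deleted loop pre-encoded every value to UTF-8 bytes before calling `urlencode`, but with its second argument (`doseq`) set to True, `urllib.parse.urlencode` already accepts str and bytes values directly and expands iterable values into repeated keys, which is exactly what the new `QueryParams` alias permits. A quick self-contained check of that behaviour:

    import urllib.parse

    # doseq=True expands sequence values into repeated keys...
    print(urllib.parse.urlencode({"a": ["x", "y"], "b": "z"}, True))
    # -> a=x&a=y&b=z

    # ...and quotes bytes keys/values without a manual encode step
    print(urllib.parse.urlencode({b"raw": b"\x01"}, True))
    # -> raw=%01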

synapse/http/matrixfederationclient.py
@@ -67,6 +67,7 @@ from synapse.http.client import (
     read_body_with_max_size,
 )
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.types import QueryParams
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
@@ -98,10 +99,6 @@ MAXINT = sys.maxsize

 _next_id = 1


-QueryArgs = Dict[str, Union[str, List[str]]]
-
-
 T = TypeVar("T")


@@ -144,7 +141,7 @@ class MatrixFederationRequest:
     """A callback to generate the JSON.
     """

-    query: Optional[dict] = None
+    query: Optional[QueryParams] = None
     """Query arguments.
     """

@@ -165,10 +162,7 @@ class MatrixFederationRequest:

         destination_bytes = self.destination.encode("ascii")
         path_bytes = self.path.encode("ascii")
-        if self.query:
-            query_bytes = encode_query_args(self.query)
-        else:
-            query_bytes = b""
+        query_bytes = encode_query_args(self.query)

         # The object is frozen so we can pre-compute this.
         uri = urllib.parse.urlunparse(
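The if/else collapses because `encode_query_args` (see the client.py hunk above) already returns b"" for None, and urlencoding an empty mapping also yields b"", so the unconditional call covers every case the old branch handled. A small equivalence check against a stand-in with the same contract:

    import urllib.parse
    from typing import Mapping, Optional

    def encode_query_args(args: Optional[Mapping[str, str]]) -> bytes:
        # mirrors the real helper's contract: None -> b""
        if args is None:
            return b""
        return urllib.parse.urlencode(args, True).encode("utf8")

    for query in (None, {}, {"limit": "10"}):
        old_style = encode_query_args(query) if query else b""
        assert encode_query_args(query) == old_style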
@@ -485,10 +479,7 @@ class MatrixFederationHttpClient:
             method_bytes = request.method.encode("ascii")
             destination_bytes = request.destination.encode("ascii")
             path_bytes = request.path.encode("ascii")
-            if request.query:
-                query_bytes = encode_query_args(request.query)
-            else:
-                query_bytes = b""
+            query_bytes = encode_query_args(request.query)

             scope = start_active_span(
                 "outgoing-federation-request",
@@ -746,7 +737,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         data: Optional[JsonDict] = None,
         json_data_callback: Optional[Callable[[], JsonDict]] = None,
         long_retries: bool = False,
@@ -764,7 +755,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         data: Optional[JsonDict] = None,
         json_data_callback: Optional[Callable[[], JsonDict]] = None,
         long_retries: bool = False,
@@ -781,7 +772,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         data: Optional[JsonDict] = None,
         json_data_callback: Optional[Callable[[], JsonDict]] = None,
         long_retries: bool = False,
@@ -891,7 +882,7 @@ class MatrixFederationHttpClient:
         long_retries: bool = False,
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
     ) -> Union[JsonDict, list]:
         """Sends the specified json data using POST

@@ -961,7 +952,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
@@ -976,7 +967,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = ...,
+        args: Optional[QueryParams] = ...,
         retry_on_dns_fail: bool = ...,
         timeout: Optional[int] = ...,
         ignore_backoff: bool = ...,
@@ -990,7 +981,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
@@ -1085,7 +1076,7 @@ class MatrixFederationHttpClient:
         long_retries: bool = False,
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
     ) -> Union[JsonDict, list]:
         """Send a DELETE request to the remote expecting some json response

@@ -1150,7 +1141,7 @@ class MatrixFederationHttpClient:
         destination: str,
         path: str,
         output_stream,
-        args: Optional[QueryArgs] = None,
+        args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         max_size: Optional[int] = None,
         ignore_backoff: bool = False,

synapse/http/types.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Iterable, Mapping, Union
+
+# the type of the query params, to be passed into `urlencode` with `doseq=True`.
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
+__all__ = ["QueryParams"]
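To illustrate what the new alias admits: any mapping with str or bytes keys whose values are str, bytes, or an iterable of either type-checks as `QueryParams`. A sketch with the aliases restated locally (`takes_params` is hypothetical):

    from typing import Iterable, Mapping, Union

    QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
    QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]

    def takes_params(params: QueryParams) -> None:
        ...

    takes_params({"ver": ["1", "2"], "user_id": "@alice:example.com"})  # ok
    takes_params({b"raw": b"bytes"})                                    # ok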

synapse/logging/opentracing.py
@@ -289,6 +289,9 @@ class SynapseTags:
     # Uniqueish ID of a database transaction
     DB_TXN_ID = "db.txn_id"

+    # The name of the external cache
+    CACHE_NAME = "cache.name"
+

 class SynapseBaggage:
     FORCE_TRACING = "synapse-force-tracing"
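Presumably the new constant is attached to spans via the `set_tag` helper this module already provides; a hypothetical, self-contained sketch of such a call site (both definitions below are stand-ins):

    class SynapseTags:
        CACHE_NAME = "cache.name"  # the constant added above

    def set_tag(key: str, value: str) -> None:
        # stand-in for synapse.logging.opentracing.set_tag
        print(f"span tag: {key}={value}")

    # hypothetical call site: record which external cache a span touched
    set_tag(SynapseTags.CACHE_NAME, "remote_media")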

synapse/metrics/__init__.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -52,12 +53,13 @@ from synapse.metrics._exposition import (
     start_http_server,
 )
 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
+from synapse.metrics._types import Collector

 logger = logging.getLogger(__name__)

 METRICS_PREFIX = "/_synapse/metrics"

-all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
+all_gauges: Dict[str, Collector] = {}

 HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")

@@ -78,11 +80,10 @@ RegistryProxy = cast(CollectorRegistry, _RegistryProxy)


 @attr.s(slots=True, hash=True, auto_attribs=True)
-class LaterGauge:
-
+class LaterGauge(Collector):
     name: str
     desc: str
-    labels: Optional[Iterable[str]] = attr.ib(hash=False)
+    labels: Optional[Sequence[str]] = attr.ib(hash=False)
     # callback: should either return a value (if there are no labels for this metric),
     # or dict mapping from a label tuple to a value
     caller: Callable[
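Deriving these gauge classes from `Collector` makes the `all_gauges: Dict[str, Collector]` annotation above type-check, since each class already implements the one method the ABC requires. For reference, the callback-gauge pattern in isolation (`CallbackGauge` is a stand-in, not Synapse's `LaterGauge`):

    from typing import Callable, Iterable

    from prometheus_client import Metric
    from prometheus_client.core import GaugeMetricFamily

    class CallbackGauge:
        def __init__(self, name: str, desc: str, caller: Callable[[], float]) -> None:
            self.name = name
            self.desc = desc
            self.caller = caller

        def collect(self) -> Iterable[Metric]:
            # the value is computed lazily, at scrape time
            g = GaugeMetricFamily(self.name, self.desc)
            g.add_metric([], self.caller())
            yield g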
@@ -125,7 +126,7 @@ class LaterGauge:
 MetricsEntry = TypeVar("MetricsEntry")


-class InFlightGauge(Generic[MetricsEntry]):
+class InFlightGauge(Generic[MetricsEntry], Collector):
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.

@@ -246,7 +247,7 @@ class InFlightGauge(Generic[MetricsEntry]):
         all_gauges[self.name] = self


-class GaugeBucketCollector:
+class GaugeBucketCollector(Collector):
     """Like a Histogram, but the buckets are Gauges which are updated atomically.

     The data is updated by calling `update_data` with an iterable of measurements.
@@ -340,7 +341,7 @@ class GaugeBucketCollector:
 #


-class CPUMetrics:
+class CPUMetrics(Collector):
     def __init__(self) -> None:
         ticks_per_sec = 100
         try:
@@ -470,6 +471,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:


 __all__ = [
+    "Collector",
     "MetricsResource",
     "generate_latest",
     "start_http_server",

synapse/metrics/_gc.py
@@ -30,6 +30,8 @@ from prometheus_client.core import (

 from twisted.internet import task

+from synapse.metrics._types import Collector
+
 """Prometheus metrics for garbage collection"""


@@ -71,7 +73,7 @@ gc_time = Histogram(
 )


-class GCCounts:
+class GCCounts(Collector):
     def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
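The collect() body shown in this hunk is easy to exercise standalone; a sketch mirroring it (the class name is a stand-in, and registration with the global registry is left out):

    import gc
    from typing import Iterable

    from prometheus_client import Metric
    from prometheus_client.core import GaugeMetricFamily

    class GCCountsSketch:
        def collect(self) -> Iterable[Metric]:
            cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
            # gc.get_count() returns one allocation count per generation (0, 1, 2)
            for n, m in enumerate(gc.get_count()):
                cm.add_metric([str(n)], m)
            yield cm

    for metric in GCCountsSketch().collect():
        print(metric.samples)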
@@ -135,7 +137,7 @@ def install_gc_manager() -> None:
     #


-class PyPyGCStats:
+class PyPyGCStats(Collector):
     def collect(self) -> Iterable[Metric]:

         # @stats is a pretty-printer object with __str__() returning a nice table,

synapse/metrics/_reactor_metrics.py
@@ -21,6 +21,8 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily

 from twisted.internet import reactor

+from synapse.metrics._types import Collector
+
 #
 # Twisted reactor metrics
 #
@@ -54,7 +56,7 @@ class EpollWrapper:
         return getattr(self._poller, item)


-class ReactorLastSeenMetric:
+class ReactorLastSeenMetric(Collector):
     def __init__(self, epoll_wrapper: EpollWrapper):
         self._epoll_wrapper = epoll_wrapper


synapse/metrics/_types.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import Iterable
+
+from prometheus_client import Metric
+
+try:
+    from prometheus_client.registry import Collector
+except ImportError:
+    # prometheus_client.Collector is new as of prometheus 0.14. We redefine it here
+    # for compatibility with earlier versions.
+    class _Collector(ABC):
+        @abstractmethod
+        def collect(self) -> Iterable[Metric]:
+            pass
+
+    Collector = _Collector  # type: ignore
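With this shim in place, subclasses work the same whether or not the installed prometheus_client (0.14+) actually ships the Collector ABC. A self-contained analogue of the fallback plus a demo subclass, registered the way prometheus_client expects:

    from abc import ABC, abstractmethod
    from typing import Iterable

    from prometheus_client import REGISTRY, Metric
    from prometheus_client.core import GaugeMetricFamily

    try:
        from prometheus_client.registry import Collector
    except ImportError:
        class _Collector(ABC):
            @abstractmethod
            def collect(self) -> Iterable[Metric]:
                pass

        Collector = _Collector  # type: ignore

    class DemoCollector(Collector):
        def collect(self) -> Iterable[Metric]:
            yield GaugeMetricFamily("demo_up", "always one", value=1.0)

    REGISTRY.register(DemoCollector())  # the registry duck-types on .collect()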

synapse/metrics/background_process_metrics.py
@@ -46,6 +46,7 @@ from synapse.logging.opentracing import (
     noop_context_manager,
     start_active_span,
 )
+from synapse.metrics._types import Collector

 if TYPE_CHECKING:
     import resource
@@ -127,7 +128,7 @@ _background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set(
 _bg_metrics_lock = threading.Lock()


-class _Collector:
+class _Collector(Collector):
     """A custom metrics collector for the background process metrics.

     Ensures that all of the metrics are up-to-date with any in-flight processes

synapse/metrics/jemalloc.py
@@ -16,11 +16,13 @@ import ctypes
 import logging
 import os
 import re
-from typing import Iterable, Optional
+from typing import Iterable, Optional, overload

-from prometheus_client import Metric
+from prometheus_client import REGISTRY, Metric
+from typing_extensions import Literal

-from synapse.metrics import REGISTRY, GaugeMetricFamily
+from synapse.metrics import GaugeMetricFamily
+from synapse.metrics._types import Collector

 logger = logging.getLogger(__name__)

@@ -59,6 +61,16 @@ def _setup_jemalloc_stats() -> None:

     jemalloc = ctypes.CDLL(jemalloc_path)

+    @overload
+    def _mallctl(
+        name: str, read: Literal[True] = True, write: Optional[int] = None
+    ) -> int:
+        ...
+
+    @overload
+    def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+        ...
+
     def _mallctl(
         name: str, read: bool = True, write: Optional[int] = None
     ) -> Optional[int]:
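The two @overload stubs give the type checker a precise mapping from the `read` flag to the return type: `Literal[True]` (the default) yields int, `Literal[False]` yields None, while the real implementation keeps the loose `Optional[int]` signature. The same pattern in isolation (assuming `typing_extensions`, which this diff itself imports):

    from typing import Optional, overload

    from typing_extensions import Literal

    @overload
    def fetch(read: Literal[True] = True) -> int: ...
    @overload
    def fetch(read: Literal[False]) -> None: ...
    def fetch(read: bool = True) -> Optional[int]:
        return 42 if read else None

    value: int = fetch()          # overload resolution picks int
    nothing: None = fetch(False)  # overload resolution picks None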
@@ -134,7 +146,7 @@ def _setup_jemalloc_stats() -> None:
     except Exception as e:
         logger.warning("Failed to reload jemalloc stats: %s", e)

-    class JemallocCollector:
+    class JemallocCollector(Collector):
         """Metrics for internal jemalloc stats."""

         def collect(self) -> Iterable[Metric]:
Some files were not shown because too many files have changed in this diff.