Merge remote-tracking branch 'upstream/release-v1.57'

This commit is contained in:
Tulir Asokan 2022-04-21 13:53:47 +03:00
commit b2fa6ec9f6
248 changed files with 14616 additions and 8934 deletions

View file

@ -68,7 +68,7 @@ try:
except ImportError:
pass
__version__ = "1.56.0"
__version__ = "1.57.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -15,19 +15,19 @@
import argparse
import sys
import time
from typing import Optional
from typing import NoReturn, Optional
import nacl.signing
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey
def exit(status: int = 0, message: Optional[str] = None):
def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
if message:
print(message, file=sys.stderr)
sys.exit(status)
def format_plain(public_key: nacl.signing.VerifyKey):
def format_plain(public_key: VerifyKey) -> None:
print(
"%s:%s %s"
% (
@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
)
def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
print(
' "%s:%s": { key: "%s", expired_ts: %i }'
% (
@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
)
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
@ -94,7 +94,6 @@ def main():
message="Error reading key from file %s: %s %s"
% (file.name, type(e), e),
)
res = []
for key in res:
formatter(get_verify_key(key))

View file

@ -7,7 +7,7 @@ import sys
from synapse.config.homeserver import HomeServerConfig
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--config-dir",

View file

@ -20,7 +20,7 @@ import sys
from synapse.config.logger import DEFAULT_LOG_CONFIG
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(

View file

@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
from synapse.util.stringutils import random_string
def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(

View file

@ -9,7 +9,7 @@ import bcrypt
import yaml
def prompt_for_pass():
def prompt_for_pass() -> str:
password = getpass.getpass("Password: ")
if not password:
@ -23,7 +23,7 @@ def prompt_for_pass():
return password
def main():
def main() -> None:
bcrypt_rounds = 12
password_pepper = ""

View file

@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
def main(src_repo, dest_repo):
def main(src_repo: str, dest_repo: str) -> None:
src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin:
@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
move_media(parts[0], parts[1], src_paths, dest_paths)
def move_media(origin_server, file_id, src_paths, dest_paths):
def move_media(
origin_server: str,
file_id: str,
src_paths: MediaFilePaths,
dest_paths: MediaFilePaths,
) -> None:
"""Move the given file, and any thumbnails, to the dest repo
Args:
origin_server (str):
file_id (str):
src_paths (MediaFilePaths):
dest_paths (MediaFilePaths):
origin_server:
file_id:
src_paths:
dest_paths:
"""
logger.info("%s/%s", origin_server, file_id)
@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
)
def mkdir_and_move(original_file, dest_file):
def mkdir_and_move(original_file: str, dest_file: str) -> None:
dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname)
@ -109,10 +114,9 @@ if __name__ == "__main__":
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)
main(args.src_repo, args.dest_repo)

View file

@ -22,7 +22,7 @@ import logging
import sys
from typing import Callable, Optional
import requests as _requests
import requests
import yaml
@ -33,7 +33,6 @@ def request_registration(
shared_secret: str,
admin: bool = False,
user_type: Optional[str] = None,
requests=_requests,
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
) -> None:

View file

@ -138,9 +138,7 @@ def main() -> None:
config_args = parser.parse_args(sys.argv[1:])
config_files = find_config_files(search_paths=config_args.config_path)
config_dict = read_config_files(config_files)
config.parse_config_dict(
config_dict,
)
config.parse_config_dict(config_dict, "", "")
since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
exclude_users_with_email = config_args.exclude_emails

View file

@ -21,12 +21,29 @@ import logging
import sys
import time
import traceback
from typing import Dict, Iterable, Optional, Set
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
cast,
)
import yaml
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import TypedDict
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@ -35,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@ -66,8 +83,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import Clock
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
@ -97,6 +118,7 @@ BOOLEAN_COLUMNS = {
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"],
}
@ -158,12 +180,16 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error = None # type: Optional[str]
end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info = None
end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None
R = TypeVar("R")
class Store(
@ -187,17 +213,19 @@ class Store(
PresenceBackgroundUpdateStore,
GroupServerWorkerStore,
):
def execute(self, f, *args, **kwargs):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
def r(txn: LoggingTransaction) -> List[Tuple]:
txn.execute(sql, args)
return txn.fetchall()
return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
def insert_many_txn(
self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
) -> None:
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
@ -210,14 +238,15 @@ class Store(
logger.exception("Failed to insert: %s", table)
raise
def set_room_is_public(self, room_id, is_public):
# Note: the parent method is an `async def`.
def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
raise Exception(
"Attempt to set room_is_public during port_db: database not empty?"
)
class MockHomeserver:
def __init__(self, config):
def __init__(self, config: HomeServerConfig):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server.server_name
@ -225,21 +254,30 @@ class MockHomeserver:
"matrix-synapse"
)
def get_clock(self):
def get_clock(self) -> Clock:
return self.clock
def get_reactor(self):
def get_reactor(self) -> ISynapseReactor:
return reactor
def get_instance_name(self):
def get_instance_name(self) -> str:
return "master"
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class Porter:
def __init__(
self,
sqlite_config: Dict[str, Any],
progress: "Progress",
batch_size: int,
hs_config: HomeServerConfig,
):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
async def setup_table(self, table):
async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = await self.postgres_store.db_pool.simple_select_one(
@ -281,7 +319,7 @@ class Porter(object):
)
else:
def delete_all(txn):
def delete_all(txn: LoggingTransaction) -> None:
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
)
@ -306,7 +344,7 @@ class Porter(object):
async def get_table_constraints(self) -> Dict[str, Set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
def _get_constraints(txn):
def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
# We can pull the information about foreign key constraints out from
# the postgres schema tables.
sql = """
@ -322,7 +360,7 @@ class Porter(object):
"""
txn.execute(sql)
results = {}
results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@ -332,8 +370,13 @@ class Porter(object):
)
async def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
self,
table: str,
postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table,
@ -380,7 +423,9 @@ class Porter(object):
while True:
def r(txn):
def r(
txn: LoggingTransaction,
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
forward_rows = []
backward_rows = []
if do_forward[0]:
@ -407,6 +452,7 @@ class Porter(object):
)
if frows or brows:
assert headers is not None
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
@ -415,7 +461,8 @@ class Porter(object):
rows = frows + brows
rows = self._convert_rows(table, headers, rows)
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
assert headers is not None
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db_pool.simple_update_one_txn(
@ -437,8 +484,12 @@ class Porter(object):
return
async def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk
):
self,
postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@ -449,7 +500,7 @@ class Porter(object):
while True:
def r(txn):
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@ -463,7 +514,7 @@ class Porter(object):
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
sql = (
"INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)"
@ -517,7 +568,7 @@ class Porter(object):
self,
db_config: DatabaseConnectionConfig,
allow_outdated_version: bool = False,
):
) -> Store:
"""Builds and returns a database store using the provided configuration.
Args:
@ -539,12 +590,13 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()
return store
async def run_background_updates_on_postgres(self):
async def run_background_updates_on_postgres(self) -> None:
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
await self.postgres_store.db_pool.updates.has_completed_background_updates()
@ -556,12 +608,12 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
await self.postgres_store.db_pool.updates.do_next_background_update(100)
await self.postgres_store.db_pool.updates.do_next_background_update(True)
postgres_ready = await (
self.postgres_store.db_pool.updates.has_completed_background_updates()
)
async def run(self):
async def run(self) -> None:
"""Ports the SQLite database to a PostgreSQL database.
When a fatal error is met, its message is assigned to the global "end_error"
@ -597,7 +649,7 @@ class Porter(object):
self.progress.set_state("Creating port tables")
def create_port_table(txn):
def create_port_table(txn: LoggingTransaction) -> None:
txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
@ -610,7 +662,7 @@ class Porter(object):
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table(txn):
def alter_table(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
@ -723,12 +775,16 @@ class Porter(object):
except Exception as e:
global end_error_exec_info
end_error = str(e)
end_error_exec_info = sys.exc_info()
# Type safety: we're in an exception handler, so the exc_info() tuple
# will not be (None, None, None).
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
def _convert_rows(self, table, headers, rows):
def _convert_rows(
self, table: str, headers: List[str], rows: List[Tuple]
) -> List[Tuple]:
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@ -736,7 +792,7 @@ class Porter(object):
class BadValueException(Exception):
pass
def conv(j, col):
def conv(j: int, col: object) -> object:
if j in bool_cols:
return bool(col)
if isinstance(col, bytes):
@ -762,7 +818,7 @@ class Porter(object):
return outrows
async def _setup_sent_transactions(self):
async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
# Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000
@ -774,10 +830,10 @@ class Porter(object):
")"
)
def r(txn):
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select)
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
headers: List[str] = [column[0] for column in txn.description]
ts_ind = headers.index("ts")
@ -791,7 +847,7 @@ class Porter(object):
if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
@ -800,7 +856,7 @@ class Porter(object):
else:
max_inserted_rowid = 0
def get_start_id(txn):
def get_start_id(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
@ -825,12 +881,13 @@ class Porter(object):
},
)
def get_sent_table_size(txn):
def get_sent_table_size(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
(size,) = txn.fetchone()
return int(size)
result = txn.fetchone()
assert result is not None
return int(result[0])
remaining_count = await self.sqlite_store.execute(get_sent_table_size)
@ -838,25 +895,35 @@ class Porter(object):
return next_chunk, inserted_rows, total_count
async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
async def _get_remaining_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> int:
frows = cast(
List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
),
)
brows = await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
brows = cast(
List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
),
)
return frows[0][0] + brows[0][0]
async def _get_already_ported_count(self, table):
async def _get_already_ported_count(self, table: str) -> int:
rows = await self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,)
)
return rows[0][0]
async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
async def _get_total_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> Tuple[int, int]:
remaining, done = await make_deferred_yieldable(
defer.gatherResults(
[
@ -877,14 +944,17 @@ class Porter(object):
return done, remaining + done
async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
if not curr_id:
return
def r(txn):
def r(txn: LoggingTransaction) -> None:
assert curr_id is not None
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
@ -895,7 +965,7 @@ class Porter(object):
"setup_user_id_seq", find_max_generated_user_id_localpart
)
def r(txn):
def r(txn: LoggingTransaction) -> None:
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
@ -917,7 +987,7 @@ class Porter(object):
allow_none=True,
)
def _setup_events_stream_seqs_set_pos(txn):
def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
if curr_forward_id:
txn.execute(
"ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@ -941,17 +1011,20 @@ class Porter(object):
"""Set a sequence to the correct value."""
current_stream_ids = []
for stream_id_table in stream_id_tables:
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table=stream_id_table,
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
max_stream_id = cast(
int,
await self.sqlite_store.db_pool.simple_select_one_onecol(
table=stream_id_table,
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
),
)
current_stream_ids.append(max_stream_id)
next_id = max(current_stream_ids) + 1
def r(txn):
def r(txn: LoggingTransaction) -> None:
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
txn.execute(sql + " %s", (next_id,))
@ -960,14 +1033,18 @@ class Porter(object):
)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_chain_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
def r(txn):
def r(txn: LoggingTransaction) -> None:
# Presumably there is at least one row in event_auth_chains.
assert curr_chain_id is not None
txn.execute(
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id + 1,),
@ -985,15 +1062,22 @@ class Porter(object):
##############################################
class Progress(object):
class TableProgress(TypedDict):
start: int
num_done: int
total: int
perc: int
class Progress:
"""Used to report progress of the port"""
def __init__(self):
self.tables = {}
def __init__(self) -> None:
self.tables: Dict[str, TableProgress] = {}
self.start_time = int(time.time())
def add_table(self, table, cur, size):
def add_table(self, table: str, cur: int, size: int) -> None:
self.tables[table] = {
"start": cur,
"num_done": cur,
@ -1001,19 +1085,22 @@ class Progress(object):
"perc": int(cur * 100 / size),
}
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
def done(self):
def done(self) -> None:
pass
def set_state(self, state: str) -> None:
pass
class CursesProgress(Progress):
"""Reports progress to a curses window"""
def __init__(self, stdscr):
def __init__(self, stdscr: "curses.window"):
self.stdscr = stdscr
curses.use_default_colors()
@ -1022,7 +1109,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.last_update = 0.0
self.finished = False
@ -1031,7 +1118,7 @@ class CursesProgress(Progress):
super(CursesProgress, self).__init__()
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
@ -1042,7 +1129,7 @@ class CursesProgress(Progress):
self.render()
def render(self, force=False):
def render(self, force: bool = False) -> None:
now = time.time()
if not force and now - self.last_update < 0.2:
@ -1081,8 +1168,7 @@ class CursesProgress(Progress):
left_margin = 5
middle_space = 1
items = self.tables.items()
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@ -1115,12 +1201,12 @@ class CursesProgress(Progress):
self.stdscr.refresh()
self.last_update = time.time()
def done(self):
def done(self) -> None:
self.finished = True
self.render(True)
self.stdscr.getch()
def set_state(self, state):
def set_state(self, state: str) -> None:
self.stdscr.clear()
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh()
@ -1129,7 +1215,7 @@ class CursesProgress(Progress):
class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
@ -1138,7 +1224,7 @@ class TerminalProgress(Progress):
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
)
def set_state(self, state):
def set_state(self, state: str) -> None:
print(state + "...")
@ -1146,7 +1232,7 @@ class TerminalProgress(Progress):
##############################################
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
@ -1178,15 +1264,11 @@ def main():
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
filename="port-synapse.log" if args.curses else None,
)
sqlite_config = {
"name": "sqlite3",
@ -1216,7 +1298,8 @@ def main():
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
def start(stdscr=None):
def start(stdscr: Optional["curses.window"] = None) -> None:
progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else:
@ -1230,7 +1313,7 @@ def main():
)
@defer.inlineCallbacks
def run():
def run() -> Generator["defer.Deferred[Any]", Any, None]:
with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run())

View file

@ -24,7 +24,7 @@ import signal
import subprocess
import sys
import time
from typing import Iterable, Optional
from typing import Iterable, NoReturn, Optional, TextIO
import yaml
@ -45,7 +45,7 @@ one of the following:
--------------------------------------------------------------------------------"""
def pid_running(pid):
def pid_running(pid: int) -> bool:
try:
os.kill(pid, 0)
except OSError as err:
@ -68,7 +68,7 @@ def pid_running(pid):
return True
def write(message, colour=NORMAL, stream=sys.stdout):
def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
# Lets check if we're writing to a TTY before colouring
should_colour = False
try:
@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n")
def abort(message, colour=RED, stream=sys.stderr):
def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
write(message, colour, stream)
sys.exit(1)
@ -166,7 +166,7 @@ Worker = collections.namedtuple(
)
def main():
def main() -> None:
parser = argparse.ArgumentParser()

View file

@ -16,42 +16,47 @@
import argparse
import logging
import sys
from typing import cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.types import ISynapseReactor
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config, **kwargs):
def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__(
config.server.server_name, reactor=reactor, config=config, **kwargs
)
self.version_string = "Synapse/" + get_distribution_version_string(
"matrix-synapse"
hostname=config.server.server_name,
config=config,
reactor=reactor,
version_string="Synapse/"
+ get_distribution_version_string("matrix-synapse"),
)
def run_background_updates(hs):
def run_background_updates(hs: HomeServer) -> None:
store = hs.get_datastores().main
async def run_background_updates():
async def run_background_updates() -> None:
await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run.
reactor.stop()
def run():
def run() -> None:
# Apply all background updates on the database.
defer.ensureDeferred(
run_as_background_process("background_updates", run_background_updates)
@ -62,7 +67,7 @@ def run_background_updates(hs):
reactor.run()
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description=(
"Updates a synapse database to the latest schema and optionally runs background updates"
@ -85,12 +90,10 @@ def main():
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)
# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.database_config)

View file

@ -130,7 +130,7 @@ def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Optional[Tuple[int, int, int]],
pid_file: str,
pid_file: Optional[str],
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
@ -171,6 +171,8 @@ def start_reactor(
# appearing to go backwards.
with PreserveLoggingContext():
if daemonize:
assert pid_file is not None
if print_pidfile:
print(pid_file)

View file

@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
SlavedDeviceStore,
SlavedPushRuleStore,
SlavedEventStore,
SlavedClientIpStore,
BaseSlavedStore,
RoomWorkerStore,
):

View file

@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
@ -247,7 +246,6 @@ class GenericWorkerSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedProfileStore,
SlavedClientIpStore,
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,7 +23,13 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.types import (
DeviceListUpdates,
GroupID,
JsonDict,
UserID,
get_domain_from_id,
)
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@ -400,6 +407,7 @@ class AppServiceTransaction:
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
):
self.service = service
self.id = id
@ -408,6 +416,7 @@ class AppServiceTransaction:
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
self.device_list_summary = device_list_summary
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@ -424,6 +433,7 @@ class AppServiceTransaction:
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
device_list_summary=self.device_list_summary,
txn_id=self.id,
)

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,7 +28,7 @@ from synapse.appservice import (
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None,
) -> bool:
"""
@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
# TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
if device_list_summary:
body["org.matrix.msc3202.device_lists"] = {
"changed": list(device_list_summary.changed),
"left": list(device_list_summary.left),
}
try:
await self.put_json(

View file

@ -72,7 +72,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock
if TYPE_CHECKING:
@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
Enqueue some data to be sent off to an application service.
@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
device_list_summary: A summary of users that the application service either needs
to refresh the device lists of, or those that the application service need no
longer track the device lists of.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages:
if (
not events
and not ephemeral
and not to_device_messages
and not device_list_summary
):
return
if events:
@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
if device_list_summary:
self.queuer.queued_device_list_summaries.setdefault(
appservice.id, []
).append(device_list_summary)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
@ -169,6 +182,8 @@ class _ServiceQueuer:
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
@ -212,7 +227,35 @@ class _ServiceQueuer:
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send:
# Consolidate any pending device list summaries into a single, up-to-date
# summary.
# Note: this code assumes that in a single DeviceListUpdates, a user will
# never be in both "changed" and "left" sets.
device_list_summary = DeviceListUpdates()
for summary in self.queued_device_list_summaries.get(service.id, []):
# For every user in the incoming "changed" set:
# * Remove them from the existing "left" set if necessary
# (as we need to start tracking them again)
# * Add them to the existing "changed" set if necessary.
device_list_summary.left.difference_update(summary.changed)
device_list_summary.changed.update(summary.changed)
# For every user in the incoming "left" set:
# * Remove them from the existing "changed" set if necessary
# (we no longer need to track them)
# * Add them to the existing "left" set if necessary.
device_list_summary.changed.difference_update(summary.left)
device_list_summary.left.update(summary.left)
self.queued_device_list_summaries.clear()
if (
not events
and not ephemeral
and not to_device_messages_to_send
# DeviceListUpdates is True if either the 'changed' or 'left' sets have
# at least one entry, otherwise False
and not device_list_summary
):
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
@ -240,6 +283,7 @@ class _ServiceQueuer:
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
device_list_summary,
)
except Exception:
logger.exception("AS request failed")
@ -322,6 +366,7 @@ class _TransactionController:
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@ -336,6 +381,7 @@ class _TransactionController:
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@ -345,6 +391,7 @@ class _TransactionController:
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceListUpdates(),
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View file

@ -702,10 +702,7 @@ class RootConfig:
return obj
def parse_config_dict(
self,
config_dict: Dict[str, Any],
config_dir_path: Optional[str] = None,
data_dir_path: Optional[str] = None,
self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
) -> None:
"""Read the information from the config dict into this Config object.

View file

@ -126,10 +126,7 @@ class RootConfig:
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def parse_config_dict(
self,
config_dict: Dict[str, Any],
config_dir_path: Optional[str] = ...,
data_dir_path: Optional[str] = ...,
self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str
) -> None: ...
def generate_config(
self,

View file

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -29,7 +31,7 @@ https://matrix-org.github.io/synapse/latest/templates.html
class AccountValidityConfig(Config):
section = "account_validity"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"""Parses the old account validity config. The config format looks like this:
account_validity:

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import Iterable
from typing import Any, Iterable
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
@ -26,12 +26,12 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
def read_config(self, config: JsonDict, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
def generate_config_section(cls, **kwargs) -> str:
def generate_config_section(cls, **kwargs: Any) -> str:
formatted_default_state_types = "\n".join(
" # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, List
from typing import Any, Dict, List
from urllib import parse as urlparse
import yaml
@ -31,12 +31,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
section = "appservice"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def generate_config_section(cls, **kwargs) -> str:
def generate_config_section(cls, **kwargs: Any) -> str:
return """\
# A list of application service config files to use
#
@ -170,6 +170,7 @@ def _load_appservice(
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
# - device list changes
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
@ -21,7 +24,7 @@ class AuthConfig(Config):
section = "auth"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
password_config = config.get("password_config", {})
if password_config is None:
password_config = {}
@ -40,7 +43,7 @@ class AuthConfig(Config):
ui_auth.get("session_timeout", 0)
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
password_config:
# Uncomment to disable password login

View file

@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
@ -18,7 +21,7 @@ from ._base import Config
class BackgroundUpdateConfig(Config):
section = "background_updates"
def generate_config_section(self, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Background Updates ##
@ -52,7 +55,7 @@ class BackgroundUpdateConfig(Config):
#default_batch_size: 50
"""
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
bg_update_config = config.get("background_updates") or {}
self.update_duration_ms = bg_update_config.get(

View file

@ -16,10 +16,11 @@ import logging
import os
import re
import threading
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional
import attr
from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@ -105,7 +106,7 @@ class CacheConfig(Config):
with _CACHES_LOCK:
_CACHES.clear()
def generate_config_section(self, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Caching ##
@ -172,7 +173,7 @@ class CacheConfig(Config):
#sync_response_cache_duration: 2m
"""
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)

View file

@ -12,15 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from typing import Any
from synapse.types import JsonDict
from ._base import Config, ConfigError
class CaptchaConfig(Config):
section = "captcha"
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
recaptcha_private_key = config.get("recaptcha_private_key")
if recaptcha_private_key is not None and not isinstance(
recaptcha_private_key, str
):
raise ConfigError("recaptcha_private_key must be a string.")
self.recaptcha_private_key = recaptcha_private_key
recaptcha_public_key = config.get("recaptcha_public_key")
if recaptcha_public_key is not None and not isinstance(
recaptcha_public_key, str
):
raise ConfigError("recaptcha_public_key must be a string.")
self.recaptcha_public_key = recaptcha_public_key
self.enable_registration_captcha = config.get(
"enable_registration_captcha", False
)
@ -30,7 +46,7 @@ class CaptchaConfig(Config):
)
self.recaptcha_template = self.read_template("recaptcha.html")
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP.md for full details of configuring this.

View file

@ -16,6 +16,7 @@
from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
from ._base import Config
from ._util import validate_config
@ -29,7 +30,7 @@ class CasConfig(Config):
section = "cas"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
cas_config = config.get("cas_config", None)
self.cas_enabled = cas_config and cas_config.get("enabled", True)
@ -52,7 +53,7 @@ class CasConfig(Config):
self.cas_displayname_attribute = None
self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#

View file

@ -13,9 +13,10 @@
# limitations under the License.
from os import path
from typing import Optional
from typing import Any, Optional
from synapse.config import ConfigError
from synapse.types import JsonDict
from ._base import Config
@ -76,18 +77,18 @@ class ConsentConfig(Config):
section = "consent"
def __init__(self, *args):
def __init__(self, *args: Any):
super().__init__(*args)
self.user_consent_version: Optional[str] = None
self.user_consent_template_dir: Optional[str] = None
self.user_consent_server_notice_content = None
self.user_consent_server_notice_content: Optional[JsonDict] = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error = None
self.block_events_without_consent_error: Optional[str] = None
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
consent_config = config.get("user_consent")
self.terms_template = self.read_template("terms.html")
@ -118,5 +119,5 @@ class ConsentConfig(Config):
"policy_name", "Privacy Policy"
)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return DEFAULT_CONFIG

View file

@ -15,8 +15,10 @@
import argparse
import logging
import os
from typing import Any, List
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -121,12 +123,12 @@ class DatabaseConnectionConfig:
class DatabaseConfig(Config):
section = "database"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, *args: Any):
super().__init__(*args)
self.databases = []
self.databases: List[DatabaseConnectionConfig] = []
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra
@ -170,7 +172,7 @@ class DatabaseConfig(Config):
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)
def generate_config_section(self, data_dir_path, **kwargs) -> str:
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}

View file

@ -19,9 +19,12 @@ import email.utils
import logging
import os
from enum import Enum
from typing import Any
import attr
from synapse.types import JsonDict
from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
@ -73,7 +76,7 @@ class EmailSubjectConfig:
class EmailConfig(Config):
section = "email"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# TODO: We should separate better the email configuration from the notification
# and account validity config.
@ -354,7 +357,7 @@ class EmailConfig(Config):
path=("email", "invite_client_location"),
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return (
"""\
# Configuration for sending emails from Synapse.

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.config._base import Config
from synapse.types import JsonDict
@ -21,13 +23,11 @@ class ExperimentalConfig(Config):
section = "experimental"
def read_config(self, config: JsonDict, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
experimental = config.get("experimental_features") or {}
# MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
# MSC3666: including bundled relations in /search.
self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)
# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
# The portion of MSC3202 related to transaction extensions:
# sending device list changes, one-time key counts and fallback key
# usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)
@ -77,3 +78,6 @@ class ExperimentalConfig(Config):
# The deprecated groups feature.
self.groups_enabled: bool = experimental.get("groups_enabled", True)
# MSC2654: Unread counts
self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)

View file

@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Optional
from synapse.config._base import Config
from synapse.config._util import validate_config
from synapse.types import JsonDict
class FederationConfig(Config):
section = "federation"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist: Optional[dict] = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
@ -48,7 +49,7 @@ class FederationConfig(Config):
"allow_device_name_lookup_over_federation", True
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Federation ##

View file

@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
class GroupsConfig(Config):
section = "groups"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Uncomment to allow non-server-admin users to create groups on this server
#

View file

@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config, ConfigError
MISSING_JWT = """Missing jwt library. This is required for jwt login.
@ -24,7 +28,7 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
section = "jwt"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
@ -52,7 +56,7 @@ class JWTConfig(Config):
self.jwt_issuer = None
self.jwt_audiences = None
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# JSON web token integration. The following settings can be used to make
# Synapse JSON web tokens for authentication, instead of its internal

View file

@ -16,7 +16,7 @@
import hashlib
import logging
import os
from typing import Any, Dict, Iterator, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
import attr
import jsonschema
@ -38,6 +38,9 @@ from synapse.util.stringutils import random_string, random_string_with_symbols
from ._base import Config, ConfigError
if TYPE_CHECKING:
from signedjson.key import VerifyKeyWithExpiry
INSECURE_NOTARY_ERROR = """\
Your server is configured to accept key server responses without signature
validation or TLS certificate validation. This is likely to be very insecure. If
@ -96,11 +99,14 @@ class TrustedKeyServer:
class KeyConfig(Config):
section = "key"
def read_config(self, config, config_dir_path, **kwargs):
def read_config(
self, config: JsonDict, config_dir_path: str, **kwargs: Any
) -> None:
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
assert config_dir_path is not None
signing_key_path = config.get("signing_key_path")
if signing_key_path is None:
signing_key_path = os.path.join(
@ -169,8 +175,12 @@ class KeyConfig(Config):
self.form_secret = config.get("form_secret", None)
def generate_config_section(
self, config_dir_path, server_name, generate_secrets=False, **kwargs
):
self,
config_dir_path: str,
server_name: str,
generate_secrets: bool = False,
**kwargs: Any,
) -> str:
base_key_name = os.path.join(config_dir_path, server_name)
if generate_secrets:
@ -300,7 +310,7 @@ class KeyConfig(Config):
def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict]
) -> Dict[str, VerifyKey]:
) -> Dict[str, "VerifyKeyWithExpiry"]:
if old_signing_keys is None:
return {}
keys = {}
@ -308,8 +318,8 @@ class KeyConfig(Config):
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired_ts = key_data["expired_ts"]
verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment]
verify_key.expired = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
@ -422,7 +432,7 @@ def _parse_key_servers(
server_name = server["server_name"]
result = TrustedKeyServer(server_name=server_name)
verify_keys = server.get("verify_keys")
verify_keys: Optional[Dict[str, str]] = server.get("verify_keys")
if verify_keys is not None:
result.verify_keys = {}
for key_id, key_base64 in verify_keys.items():

View file

@ -35,6 +35,7 @@ from twisted.logger import (
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
from synapse.types import JsonDict
from ._base import Config, ConfigError
@ -147,13 +148,15 @@ https://matrix-org.github.io/synapse/v1.54/structured_logging.html
class LoggingConfig(Config):
section = "logging"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
if config.get("log_file"):
raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
def generate_config_section(
self, config_dir_path: str, server_name: str, **kwargs: Any
) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return (
"""\

View file

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import attr
from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@ -37,7 +40,7 @@ class MetricsFlags:
class MetricsConfig(Config):
section = "metrics"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
self.report_stats_endpoint = config.get(
@ -67,7 +70,9 @@ class MetricsConfig(Config):
"sentry.dsn field is required when sentry integration is enabled"
)
def generate_config_section(self, report_stats=None, **kwargs):
def generate_config_section(
self, report_stats: Optional[bool] = None, **kwargs: Any
) -> str:
res = """\
## Metrics ###

View file

@ -14,13 +14,14 @@
from typing import Any, Dict, List, Tuple
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
class ModulesConfig(Config):
section = "modules"
def read_config(self, config: dict, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.loaded_modules: List[Tuple[Any, Dict]] = []
configured_modules = config.get("modules") or []
@ -31,7 +32,7 @@ class ModulesConfig(Config):
self.loaded_modules.append(load_module(module, config_path))
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """
## Modules ##

View file

@ -40,7 +40,7 @@ class OembedConfig(Config):
section = "oembed"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
oembed_config: Dict[str, Any] = config.get("oembed") or {}
# A list of patterns which will be used.
@ -143,7 +143,7 @@ class OembedConfig(Config):
)
return re.compile(pattern)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# oEmbed allows for easier embedding content from a website. It can be
# used for generating URLs previews of services which support it.

View file

@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr
class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return
@ -66,7 +66,7 @@ class OIDCConfig(Config):
# OIDC is enabled if we have a provider
return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.

View file

@ -14,6 +14,7 @@
from typing import Any, List, Tuple, Type
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@ -24,7 +25,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"""Parses the old password auth providers config. The config format looks like this:
password_providers:

View file

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
self.push_group_unread_count_by_room = push_config.get(
@ -46,7 +50,7 @@ class PushConfig(Config):
)
self.push_include_content = not redact_content
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """
## Push ##

View file

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from typing import Any, Dict, Optional
import attr
from synapse.types import JsonDict
from ._base import Config
@ -43,7 +45,7 @@ class FederationRateLimitConfig:
class RatelimitConfig(Config):
section = "ratelimiting"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
@ -142,7 +144,7 @@ class RatelimitConfig(Config):
},
)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Ratelimiting ##

View file

@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.config._base import Config
from synapse.types import JsonDict
from synapse.util.check_dependencies import check_requirements
class RedisConfig(Config):
section = "redis"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
redis_config = config.get("redis") or {}
self.redis_enabled = redis_config.get("enabled", False)
@ -32,7 +35,7 @@ class RedisConfig(Config):
self.redis_port = redis_config.get("port", 6379)
self.redis_password = redis_config.get("password")
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).

View file

@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Optional
from typing import Any, Optional
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols, strtobool
class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_registration = strtobool(
str(config.get("enable_registration", False))
)
@ -196,7 +196,9 @@ class RegistrationConfig(Config):
self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False)
def generate_config_section(self, generate_secrets=False, **kwargs):
def generate_config_section(
self, generate_secrets: bool = False, **kwargs: Any
) -> str:
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),

View file

@ -14,7 +14,7 @@
import logging
import os
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
import attr
@ -95,7 +95,7 @@ def parse_thumbnail_requirements(
class ContentRepositoryConfig(Config):
section = "media"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
@ -224,7 +224,8 @@ class ContentRepositoryConfig(Config):
"url_preview_accept_language"
) or ["en"]
def generate_config_section(self, data_dir_path, **kwargs):
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
assert data_dir_path is not None
media_store = os.path.join(data_dir_path, "media_store")
formatted_thumbnail_sizes = "".join(

View file

@ -13,11 +13,12 @@
# limitations under the License.
import logging
from typing import List, Optional
from typing import Any, List, Optional
import attr
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -34,7 +35,7 @@ class RetentionPurgeJob:
class RetentionConfig(Config):
section = "retention"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
retention_config = config.get("retention")
if retention_config is None:
retention_config = {}
@ -153,7 +154,7 @@ class RetentionConfig(Config):
RetentionPurgeJob(self.parse_duration("1d"), None, None)
]
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Message retention policy at the server level.
#

View file

@ -13,8 +13,10 @@
# limitations under the License.
import logging
from typing import Any
from synapse.api.constants import RoomCreationPreset
from synapse.types import JsonDict
from ._base import Config, ConfigError
@ -32,7 +34,7 @@ class RoomDefaultEncryptionTypes:
class RoomConfig(Config):
section = "room"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# Whether new, locally-created rooms should have encryption enabled
encryption_for_room_type = config.get(
"encryption_enabled_by_default_for_room_type",
@ -61,7 +63,7 @@ class RoomConfig(Config):
"Invalid value for encryption_enabled_by_default_for_room_type"
)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Rooms ##

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Any, List
from matrix_common.regex import glob_to_regex
@ -25,7 +25,7 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
section = "roomdirectory"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@ -52,7 +52,7 @@ class RoomDirectoryConfig(Config):
_RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote

View file

@ -65,7 +65,7 @@ def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
class SAML2Config(Config):
section = "saml2"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@ -165,13 +165,13 @@ class SAML2Config(Config):
config_path = saml2_config.get("config_path", None)
if config_path is not None:
mod = load_python_module(config_path)
config = getattr(mod, "CONFIG", None)
if config is None:
config_dict_from_file = getattr(mod, "CONFIG", None)
if config_dict_from_file is None:
raise ConfigError(
"Config path specified by saml2_config.config_path does not "
"have a CONFIG property."
)
_dict_merge(merge_dict=config, into_dict=saml2_config_dict)
_dict_merge(merge_dict=config_dict_from_file, into_dict=saml2_config_dict)
import saml2.config
@ -223,7 +223,7 @@ class SAML2Config(Config):
},
}
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
def generate_config_section(self, config_dir_path: str, **kwargs: Any) -> str:
return """\
## Single sign-on integration ##

View file

@ -248,7 +248,7 @@ class LimitRemoteRoomsConfig:
class ServerConfig(Config):
section = "server"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@ -259,8 +259,8 @@ class ServerConfig(Config):
self.pid_file = self.abspath(config.get("pid_file"))
self.soft_file_limit = config.get("soft_file_limit", 0)
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.daemonize = bool(config.get("daemonize"))
self.print_pidfile = bool(config.get("print_pidfile"))
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.serve_server_wellknown = config.get("serve_server_wellknown", False)
@ -680,18 +680,30 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False
)
# This is a temporary option that enables fully using the new
# `device_lists_changes_in_room` without the backwards compat code. This
# is primarily for testing. If enabled the server should *not* be
# downgraded, as it may lead to missing device list updates.
self.use_new_device_lists_changes_in_room = (
config.get("use_new_device_lists_changes_in_room") or False
)
self.rooms_to_exclude_from_sync: List[str] = (
config.get("exclude_rooms_from_sync") or []
)
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
def generate_config_section(
self,
server_name,
data_dir_path,
open_private_ports,
listeners,
config_dir_path,
**kwargs,
):
config_dir_path: str,
data_dir_path: str,
server_name: str,
open_private_ports: bool,
listeners: Optional[List[dict]],
**kwargs: Any,
) -> str:
ip_range_blacklist = "\n".join(
" # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
)
@ -1234,6 +1246,15 @@ class ServerConfig(Config):
# information about using custom templates.
#
#custom_template_directory: /path/to/custom/templates/
# List of rooms to exclude from sync responses. This is useful for server
# administrators wishing to group users into a room without these users being able
# to see it from their client.
#
# By default, no room is excluded.
#
#exclude_rooms_from_sync:
# - !foo:example.com
"""
% locals()
)

View file

@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import UserID
from typing import Any, Optional
from synapse.types import JsonDict, UserID
from ._base import Config
@ -60,14 +63,14 @@ class ServerNoticesConfig(Config):
section = "servernotices"
def __init__(self, *args):
def __init__(self, *args: Any):
super().__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None
self.server_notices_room_name = None
self.server_notices_mxid: Optional[str] = None
self.server_notices_mxid_display_name: Optional[str] = None
self.server_notices_mxid_avatar_url: Optional[str] = None
self.server_notices_room_name: Optional[str] = None
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
c = config.get("server_notices")
if c is None:
return
@ -81,5 +84,5 @@ class ServerNoticesConfig(Config):
# todo: i18n
self.server_notices_room_name = c.get("room_name", "Server Notices")
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return DEFAULT_CONFIG

View file

@ -16,6 +16,7 @@ import logging
from typing import Any, Dict, List, Tuple
from synapse.config import ConfigError
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@ -33,7 +34,7 @@ see https://matrix-org.github.io/synapse/latest/modules/index.html
class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.spam_checkers: List[Tuple[Any, Dict]] = []
spam_checkers = config.get("spam_checker") or []

View file

@ -16,6 +16,8 @@ from typing import Any, Dict, Optional
import attr
from synapse.types import JsonDict
from ._base import Config
logger = logging.getLogger(__name__)
@ -49,7 +51,7 @@ class SSOConfig(Config):
section = "sso"
def read_config(self, config, **kwargs) -> None:
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir
@ -106,7 +108,7 @@ class SSOConfig(Config):
)
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs) -> str:
def generate_config_section(self, **kwargs: Any) -> str:
return """\
# Additional settings to use with single-sign on systems such as OpenID Connect,
# SAML2 and CAS.

View file

@ -13,6 +13,9 @@
# limitations under the License.
import logging
from typing import Any
from synapse.types import JsonDict
from ._base import Config
@ -36,7 +39,7 @@ class StatsConfig(Config):
section = "stats"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.stats_enabled = True
stats_config = config.get("stats", None)
if stats_config:
@ -44,7 +47,7 @@ class StatsConfig(Config):
if not self.stats_enabled:
logger.warning(ROOM_STATS_DISABLED_WARN)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """
# Settings for local room and user statistics collection. See
# https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.

View file

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config
@ -20,7 +23,7 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
section = "thirdpartyrules"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.third_party_event_rules = None
provider = config.get("third_party_event_rules", None)

View file

@ -14,7 +14,7 @@
import logging
import os
from typing import List, Optional, Pattern
from typing import Any, List, Optional, Pattern
from matrix_common.regex import glob_to_regex
@ -22,6 +22,7 @@ from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
section = "tls"
def read_config(self, config: dict, config_dir_path: str, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@ -142,13 +143,13 @@ class TlsConfig(Config):
def generate_config_section(
self,
config_dir_path,
server_name,
data_dir_path,
tls_certificate_path,
tls_private_key_path,
**kwargs,
):
config_dir_path: str,
data_dir_path: str,
server_name: str,
tls_certificate_path: Optional[str],
tls_private_key_path: Optional[str],
**kwargs: Any,
) -> str:
"""If the TLS paths are not specified the default will be certs in the
config directory"""

View file

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Set
from typing import Any, Set
from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@ -22,7 +23,7 @@ from ._base import Config, ConfigError
class TracerConfig(Config):
section = "tracing"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
opentracing_config = config.get("opentracing")
if opentracing_config is None:
opentracing_config = {}
@ -65,7 +66,7 @@ class TracerConfig(Config):
)
self.force_tracing_for_users.add(u)
def generate_config_section(cls, **kwargs):
def generate_config_section(cls, **kwargs: Any) -> str:
return """\
## Opentracing ##

View file

@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
@ -22,7 +26,7 @@ class UserDirectoryConfig(Config):
section = "userdirectory"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_directory_config = config.get("user_directory") or {}
self.user_directory_search_enabled = user_directory_config.get("enabled", True)
self.user_directory_search_all_users = user_directory_config.get(
@ -32,7 +36,7 @@ class UserDirectoryConfig(Config):
"prefer_local_users", False
)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """
# User Directory configuration
#

View file

@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from synapse.types import JsonDict
from ._base import Config
class VoipConfig(Config):
section = "voip"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
@ -28,7 +32,7 @@ class VoipConfig(Config):
)
self.turn_allow_guests = config.get("turn_allow_guests", True)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## TURN ##

View file

@ -14,10 +14,12 @@
# limitations under the License.
import argparse
from typing import List, Union
from typing import Any, List, Union
import attr
from synapse.types import JsonDict
from ._base import (
Config,
ConfigError,
@ -110,7 +112,7 @@ class WorkerConfig(Config):
section = "worker"
def read_config(self, config, **kwargs):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
@ -120,9 +122,13 @@ class WorkerConfig(Config):
self.worker_listeners = [
parse_listener_def(x) for x in config.get("worker_listeners", [])
]
self.worker_daemonize = config.get("worker_daemonize")
self.worker_daemonize = bool(config.get("worker_daemonize"))
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_config = config.get("worker_log_config")
worker_log_config = config.get("worker_log_config")
if worker_log_config is not None and not isinstance(worker_log_config, str):
raise ConfigError("worker_log_config must be a string")
self.worker_log_config = worker_log_config
# The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None)
@ -290,7 +296,7 @@ class WorkerConfig(Config):
self.worker_name is None and background_tasks_instance == "master"
) or self.worker_name == background_tasks_instance
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, **kwargs: Any) -> str:
return """\
## Workers ##

View file

@ -176,7 +176,7 @@ class Keyring:
self._local_verify_keys: Dict[str, FetchKeyResult] = {}
for key_id, key in hs.config.key.old_signing_keys.items():
self._local_verify_keys[key_id] = FetchKeyResult(
verify_key=key, valid_until_ts=key.expired_ts
verify_key=key, valid_until_ts=key.expired
)
vk = get_verify_key(hs.signing_key)

View file

@ -15,7 +15,7 @@ import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from nacl.signing import SigningKey
from signedjson.types import SigningKey
from synapse.api.constants import MAX_DEPTH
from synapse.api.room_versions import (

View file

@ -42,6 +42,7 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@ -169,6 +170,7 @@ class ThirdPartyEventRules:
self._on_user_deactivation_status_changed_callbacks: List[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = []
self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
def register_third_party_rules_callbacks(
self,
@ -187,6 +189,7 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
@ -221,6 +224,9 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed,
)
if on_threepid_bind is not None:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> Tuple[bool, Optional[dict]]:
@ -479,3 +485,23 @@ class ThirdPartyEventRules:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)
async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
"""Called after a threepid association has been verified and stored.
Note that this callback is called when an association is created on the
local homeserver, not when it's created on an identity server (and then kept track
of so that it can be unbound on the same IS later on).
Args:
user_id: the user being associated with the threepid.
medium: the threepid's medium.
address: the threepid's address.
"""
for callback in self._on_threepid_bind_callbacks:
try:
await callback(user_id, medium, address)
except Exception as e:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)

View file

@ -56,6 +56,7 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@ -154,7 +155,7 @@ class FederationClient(FederationBase):
self,
destination: str,
query_type: str,
args: dict,
args: QueryParams,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:

View file

@ -188,7 +188,7 @@ class FederationServer(FederationBase):
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))):
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@ -218,7 +218,7 @@ class FederationServer(FederationBase):
Tuple indicating the response status code and dictionary response
body including `event_id`.
"""
with (await self._server_linearizer.queue((origin, room_id))):
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@ -529,7 +529,7 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
async with self._server_linearizer.queue((origin, room_id)):
resp: JsonDict = dict(
await self._state_resp_cache.wrap(
(room_id, event_id),
@ -883,7 +883,7 @@ class FederationServer(FederationBase):
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))):
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@ -945,7 +945,7 @@ class FederationServer(FederationBase):
latest_events: List[str],
limit: int,
) -> Dict[str, list]:
with (await self._server_linearizer.queue((origin, room_id))):
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)

View file

@ -44,6 +44,7 @@ from synapse.api.urls import (
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -255,7 +256,7 @@ class TransportLayerClient:
self,
destination: str,
query_type: str,
args: dict,
args: QueryParams,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
prefix: str = FEDERATION_V1_PREFIX,
@ -481,7 +482,7 @@ class TransportLayerClient:
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
data["limit"] = str(limit)
data["limit"] = limit
if since_token:
data["since"] = since_token
@ -503,7 +504,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/publicRooms")
args: Dict[str, Any] = {
args: Dict[str, Union[str, Iterable[str]]] = {
"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:

View file

@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@ -27,6 +28,12 @@ from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
[str, Optional[str], str, JsonDict], Awaitable
]
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
@ -40,6 +47,44 @@ class AccountDataHandler:
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
self._on_account_data_updated_callbacks: List[
ON_ACCOUNT_DATA_UPDATED_CALLBACK
] = []
def register_module_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
) -> None:
"""Register callbacks from modules."""
if on_account_data_updated is not None:
self._on_account_data_updated_callbacks.append(on_account_data_updated)
async def _notify_modules(
self,
user_id: str,
room_id: Optional[str],
account_data_type: str,
content: JsonDict,
) -> None:
"""Notifies modules about new account data changes.
A change can be either a new account data type being added, or the content
associated with a type being changed. Account data for a given type is removed by
changing the associated content to an empty dictionary.
Note that this is not called when the tags associated with a room change.
Args:
user_id: The user whose account data is changing.
room_id: The ID of the room the account data change concerns, if any.
account_data_type: The type of the account data.
content: The content that is now associated with this type.
"""
for callback in self._on_account_data_updated_callbacks:
try:
await callback(user_id, room_id, account_data_type, content)
except Exception as e:
logger.exception("Failed to run module callback %s: %s", callback, e)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
@ -63,6 +108,8 @@ class AccountDataHandler:
"account_data_key", max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, room_id, account_data_type, content)
return max_stream_id
else:
response = await self._room_data_client(
@ -96,6 +143,9 @@ class AccountDataHandler:
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
await self._notify_modules(user_id, None, account_data_type, content)
return max_stream_id
else:
response = await self._user_data_client(

View file

@ -180,9 +180,9 @@ class AccountValidityHandler:
expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
user_id=user_id, expiration_ts=expiration_ts_ms
)
async def send_renewal_email_to_user(self, user_id: str) -> None:

View file

@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.types import (
DeviceListUpdates,
JsonDict,
RoomAlias,
RoomStreamToken,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
@ -58,6 +64,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self._msc3202_transaction_extensions_enabled = (
hs.config.experimental.msc3202_transaction_extensions
)
self.current_max = 0
self.is_processing = False
@ -204,9 +213,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key",
"to_device_key" or "device_list_key". Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@ -230,6 +239,7 @@ class ApplicationServicesHandler:
"receipt_key",
"presence_key",
"to_device_key",
"device_list_key",
):
return
@ -253,15 +263,37 @@ class ApplicationServicesHandler:
):
return
# Ignore device lists if the feature flag is not enabled
if (
stream_key == "device_list_key"
and not self._msc3202_transaction_extensions_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = self.store.get_app_services()
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
for service in services
# Different stream keys require different support booleans
if (
stream_key
in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
)
and service.supports_ephemeral
)
or (
stream_key == "device_list_key"
and service.msc3202_transaction_extensions
)
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
@ -298,10 +330,8 @@ class ApplicationServicesHandler:
continue
# Since we read/update the stream position for this AS/stream
with (
await self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
)
async with self._ephemeral_events_linearizer.queue(
(service.id, stream_key)
):
if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token)
@ -336,6 +366,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
elif stream_key == "device_list_key":
device_list_summary = await self._get_device_list_summary(
service, new_token
)
if device_list_summary:
self.scheduler.enqueue_for_appservice(
service, device_list_summary=device_list_summary
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "device_list", new_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@ -542,6 +586,96 @@ class ApplicationServicesHandler:
return message_payload
async def _get_device_list_summary(
self,
appservice: ApplicationService,
new_key: int,
) -> DeviceListUpdates:
"""
Retrieve a list of users who have changed their device lists.
Args:
appservice: The application service to retrieve device list changes for.
new_key: The stream key of the device list change that triggered this method call.
Returns:
A set of device list updates, comprised of users that the appservices needs to:
* resync the device list of, and
* stop tracking the device list of.
"""
# Fetch the last successfully processed device list update stream ID
# for this appservice.
from_key = await self.store.get_type_stream_id_for_appservice(
appservice, "device_list"
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
)
# Filter out any users the application service is not interested in
#
# For each user who changed their device list, we want to check whether this
# appservice would be interested in the change.
filtered_users_with_changed_device_lists = {
user_id
for user_id in users_with_changed_device_lists
if await self._is_appservice_interested_in_device_lists_of_user(
appservice, user_id
)
}
# Create a summary of "changed" and "left" users.
# TODO: Calculate "left" users.
device_list_summary = DeviceListUpdates(
changed=filtered_users_with_changed_device_lists
)
return device_list_summary
async def _is_appservice_interested_in_device_lists_of_user(
self,
appservice: ApplicationService,
user_id: str,
) -> bool:
"""
Returns whether a given application service is interested in the device list
updates of a given user.
The application service is interested in the user's device list updates if any
of the following are true:
* The user is the appservice's sender localpart user.
* The user is in the appservice's user namespace.
* At least one member of one room that the user is a part of is in the
appservice's user namespace.
* The appservice is explicitly (via room ID or alias) interested in at
least one room that the user is in.
Args:
appservice: The application service to gauge interest of.
user_id: The ID of the user whose device list interest is in question.
Returns:
True if the application service is interested in the user's device lists, False
otherwise.
"""
# This method checks against both the sender localpart user as well as if the
# user is in the appservice's user namespace.
if appservice.is_interested_in_user(user_id):
return True
# Determine whether any of the rooms the user is in justifies sending this
# device list update to the application service.
room_ids = await self.store.get_rooms_for_user(user_id)
for room_id in room_ids:
# This method covers checking room members for appservice interest as well as
# room ID and alias checks.
if await appservice.is_interested_in_room(room_id, self.store):
return True
return False
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.

View file

@ -211,6 +211,7 @@ class AuthHandler:
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@ -1505,6 +1506,8 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
await self._third_party_rules.on_threepid_bind(user_id, medium, address)
async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
) -> bool:

View file

@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import (
JsonDict,
StreamToken,
@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# Whether `_handle_new_device_update_async` is currently processing.
self._handle_new_device_update_is_processing = False
# If a new device update may have happened while the loop was
# processing.
self._handle_new_device_update_new_data = False
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
# Used to decide if we calculate outbound pokes up front or not. By
# default we do to allow safely downgrading Synapse.
self.use_new_device_lists_changes_in_room = (
hs.config.server.use_new_device_lists_changes_in_room
)
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
room_ids = await self.store.get_rooms_for_user(user_id)
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
hosts: Optional[Set[str]] = None
if not self.use_new_device_lists_changes_in_room:
hosts = set()
set_tag("target_hosts", hosts)
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
joined_users = await self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in joined_users)
set_tag("target_hosts", hosts)
hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
user_id,
device_ids,
hosts=hosts,
room_ids=room_ids,
)
if not position:
@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
users_to_notify = users_who_share_room.union({user_id})
self.notifier.on_new_event(
"device_list_key", position, users={user_id}, rooms=room_ids
)
self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
# We may need to do some processing asynchronously.
self._handle_new_device_update_async()
if hosts:
logger.info(
@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
@wrap_as_background_process("_handle_new_device_update_async")
async def _handle_new_device_update_async(self) -> None:
"""Called when we have a new local device list update that we need to
send out over federation.
This happens in the background so as not to block the original request
that generated the device update.
"""
if self._handle_new_device_update_is_processing:
self._handle_new_device_update_new_data = True
return
self._handle_new_device_update_is_processing = True
# The stream ID we processed previous iteration (if any), and the set of
# hosts we've already poked about for this update. This is so that we
# don't poke the same remote server about the same update repeatedly.
current_stream_id = None
hosts_already_sent_to: Set[str] = set()
try:
while True:
self._handle_new_device_update_new_data = False
rows = await self.store.get_uncoverted_outbound_room_pokes()
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
if self._handle_new_device_update_new_data:
continue
else:
return
for user_id, device_id, room_id, stream_id, opentracing_context in rows:
joined_user_ids = await self.store.get_users_in_room(room_id)
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts.discard(self.server_name)
# Check if we've already sent this update to some hosts
if current_stream_id == stream_id:
hosts -= hosts_already_sent_to
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
if hosts:
logger.info(
"Sending device list update notif for %r to: %r",
user_id,
hosts,
)
for host in hosts:
self.federation_sender.send_device_messages(
host, immediate=False
)
log_kv(
{"message": "sent device update to host", "host": host}
)
if current_stream_id != stream_id:
# Clear the set of hosts we've already sent to as we're
# processing a new update.
hosts_already_sent_to.clear()
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
finally:
self._handle_new_device_update_is_processing = False
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@ -725,7 +833,7 @@ class DeviceListUpdater:
async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."
with (await self._remote_edu_linearizer.queue(user_id)):
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates

View file

@ -118,7 +118,7 @@ class E2eKeysHandler:
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater:
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
with (await self._remote_edu_linearizer.queue(user_id)):
async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates

View file

@ -83,7 +83,7 @@ class E2eRoomKeysHandler:
# we deliberately take the lock to get keys so that changing the version
# works atomically
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
await self.store.get_e2e_room_keys_version_info(user_id, version)
@ -126,7 +126,7 @@ class E2eRoomKeysHandler:
"""
# lock for consistency with uploading
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# make sure the backup version exists
try:
version_info = await self.store.get_e2e_room_keys_version_info(
@ -187,7 +187,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
# Check that the version we're trying to upload is the current version
try:
@ -332,7 +332,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
@ -359,7 +359,7 @@ class E2eRoomKeysHandler:
}
"""
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
@ -383,7 +383,7 @@ class E2eRoomKeysHandler:
NotFoundError: if this backup version doesn't exist
"""
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
@ -413,7 +413,7 @@ class E2eRoomKeysHandler:
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
with (await self._upload_linearizer.queue(user_id)):
async with self._upload_linearizer.queue(user_id):
try:
old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version

View file

@ -151,7 +151,7 @@ class FederationHandler:
return. This is used as part of the heuristic to decide if we
should back paginate.
"""
with (await self._room_backfill.queue(room_id)):
async with self._room_backfill.queue(room_id):
return await self._maybe_backfill_inner(room_id, current_depth, limit)
async def _maybe_backfill_inner(

View file

@ -224,7 +224,7 @@ class FederationEventHandler:
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
async with self._room_pdu_linearizer.queue(pdu.room_id):
logger.info(
"Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
@ -469,6 +469,12 @@ class FederationEventHandler:
if context.rejected:
raise SynapseError(400, "Join event was rejected")
# the remote server is responsible for sending our join event to the rest
# of the federation. Indeed, attempting to do so will result in problems
# when we try to look up the state before the join (to get the server list)
# and discover that we do not have it.
event.internal_metadata.proactively_send = False
return await self.persist_events_and_notify(room_id, [(event, context)])
async def backfill(
@ -891,10 +897,24 @@ class FederationEventHandler:
logger.debug("We are also missing %i auth events", len(missing_auth_events))
missing_events = missing_desired_events | missing_auth_events
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=missing_events
)
# Making an individual request for each of 1000s of events has a lot of
# overhead. On the other hand, we don't really want to fetch all of the events
# if we already have most of them.
#
# As an arbitrary heuristic, if we are missing more than 10% of the events, then
# we fetch the whole state.
#
# TODO: might it be better to have an API which lets us do an aggregate event
# request
if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
logger.debug("Requesting complete state from remote")
await self._get_state_and_persist(destination, room_id, event_id)
else:
logger.debug("Fetching %i events from remote", len(missing_events))
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
@ -953,6 +973,27 @@ class FederationEventHandler:
return remote_state
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
) -> None:
"""Get the complete room state at a given event, and persist any new events
as outliers"""
room_version = await self._store.get_room_version(room_id)
auth_events, state_events = await self._federation_client.get_room_state(
destination, room_id, event_id=event_id, room_version=room_version
)
logger.info("/state returned %i events", len(auth_events) + len(state_events))
await self._auth_and_persist_outliers(
room_id, itertools.chain(auth_events, state_events)
)
# we also need the event itself.
if not await self._store.have_seen_event(room_id, event_id):
await self._get_events_and_persist(
destination=destination, room_id=room_id, event_ids=(event_id,)
)
async def _process_received_pdu(
self,
origin: str,

View file

@ -858,8 +858,6 @@ class IdentityHandler:
if room_type is not None:
invite_config["room_type"] = room_type
# TODO The unstable field is deprecated and should be removed in the future.
invite_config["org.matrix.msc3288.room_type"] = room_type
# If a custom web client location is available, include it in the request.
if self._web_client_location:

View file

@ -853,7 +853,7 @@ class EventCreationHandler:
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
with (await self.limiter.queue(event_dict["room_id"])):
async with self.limiter.queue(event_dict["room_id"]):
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
event_dict["room_id"],

View file

@ -441,7 +441,14 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
else:
from_token = self.hs.get_event_sources().get_current_token_for_pagination()
from_token = (
await self.hs.get_event_sources().get_current_token_for_pagination(
room_id
)
)
# We expect `/messages` to use historic pagination tokens by default but
# `/messages` should still works with live tokens when manually provided.
assert from_token.room_key.topological
if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this

View file

@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler):
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
with (await self.external_sync_linearizer.queue(process_id)):
async with self.external_sync_linearizer.queue(process_id):
prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler):
Used when the process has stopped/disappeared.
"""
with (await self.external_sync_linearizer.queue(process_id)):
async with self.external_sync_linearizer.queue(process_id):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
@ -1625,7 +1625,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# We'll actually pull the presence updates for these users at the end.
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
if from_key:
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)

View file

@ -41,7 +41,7 @@ class ReadMarkerHandler:
the read marker has changed.
"""
with await self.read_marker_linearizer.queue((room_id, user_id)):
async with self.read_marker_linearizer.queue((room_id, user_id)):
existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read"
)

View file

@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
import attr
from frozendict import frozendict
@ -20,12 +29,12 @@ from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -116,7 +125,10 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@ -130,7 +142,7 @@ class RelationsHandler:
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
[e.event_id for e in related_events]
)
events = await filter_events_for_client(
@ -152,14 +164,100 @@ class RelationsHandler:
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return_value = {
"chunk": serialized_events,
"original_event": original_event,
}
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value
async def get_relations_for_event(
self,
event_id: str,
event: EventBase,
room_id: str,
relation_type: str,
ignored_users: FrozenSet[str] = frozenset(),
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of events which relate to an event, ordered by topological ordering.
Args:
event_id: Fetch events that relate to this event ID.
event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: The type of relation.
ignored_users: The users ignored by the requesting user.
Returns:
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
# Call the underlying storage method, which is cached.
related_events, next_token = await self._main_store.get_relations_for_event(
event_id, event, room_id, relation_type, direction="f"
)
# Filter out ignored users and convert to the expected format.
related_events = [
event for event in related_events if event.sender not in ignored_users
]
return related_events, next_token
async def get_annotations_for_event(
self,
event_id: str,
room_id: str,
limit: int = 5,
ignored_users: FrozenSet[str] = frozenset(),
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend
on an event.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
ignored_users: The users ignored by the requesting user.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
# Get the base results for all users.
full_results = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id, limit
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
event_id, room_id, limit, ignored_users
)
filtered_results = []
for result in full_results:
key = (result["type"], result["key"])
if key in ignored_results:
result = result.copy()
result["count"] -= ignored_results[key]
if result["count"] <= 0:
continue
filtered_results.append(result)
return filtered_results
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
self, event: EventBase, ignored_users: FrozenSet[str]
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
@ -167,7 +265,7 @@ class RelationsHandler:
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@ -190,23 +288,125 @@ class RelationsHandler:
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
annotations = await self.get_annotations_for_event(
event_id, room_id, ignored_users=ignored_users
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
if annotations:
aggregations.annotations = {"chunk": annotations}
references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
references, next_token = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.REFERENCE,
ignored_users=ignored_users,
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
if references:
aggregations.references = {
"chunk": [{"event_id": event.event_id} for event in references]
}
if next_token:
aggregations.references["next_batch"] = await next_token.to_string(
self._main_store
)
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
event_ids: Events to get aggregations for threads.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
A dictionary mapping event ID to the thread information.
May not contain a value for all requested event IDs.
"""
user = UserID.from_string(user_id)
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
# A map of event ID to the thread aggregation.
results = {}
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
# Subtract off the count of any ignored users.
for ignored_user in ignored_users:
thread_count -= ignored_results.get((event_id, ignored_user), 0)
# This is gnarly, but if the latest event is from an ignored user,
# attempt to find one that isn't from an ignored user.
if latest_thread_event.sender in ignored_users:
room_id = latest_thread_event.room_id
# If the root event is not found, something went wrong, do
# not include a summary of the thread.
event = await self._event_handler.get_event(user, room_id, event_id)
if event is None:
continue
potential_events, _ = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.THREAD,
ignored_users,
)
# If all found events are from ignored users, do not include
# a summary of the thread.
if not potential_events:
continue
# The *last* event returned is the one that is cared about.
event = await self._event_handler.get_event(
user, room_id, potential_events[-1].event_id
)
# It is unexpected that the event will not exist.
if event is None:
logger.warning(
"Unable to fetch latest event in a thread with event ID: %s",
potential_events[-1].event_id,
)
continue
latest_thread_event = event
results[event_id] = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
@ -230,13 +430,21 @@ class RelationsHandler:
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch any ignored users of the requesting user.
ignored_users = await self._main_store.ignored_users(user_id)
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
event_result = await self._get_bundled_aggregation_for_event(
event, ignored_users
)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
#
# Note that there is no use in limiting edits by ignored users since the
# parent event should be ignored in the first place if the user is ignored.
edits = await self._main_store.get_applicable_edits(
[
event_id
@ -247,25 +455,10 @@ class RelationsHandler:
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
threads = await self.get_threads_for_events(
events_by_id.keys(), user_id, ignored_users
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
for event_id, thread in threads.items():
results.setdefault(event_id, BundledAggregations()).thread = thread
return results

View file

@ -771,7 +771,9 @@ class RoomCreationHandler:
% (user_id,),
)
visibility = config.get("visibility", None)
# The spec says rooms should default to private visibility if
# `visibility` is not specified.
visibility = config.get("visibility", "private")
is_public = visibility == "public"
if "room_id" in config:
@ -891,7 +893,7 @@ class RoomCreationHandler:
#
# we also don't need to check the requester's shadow-ban here, as we
# have already done so above (and potentially emptied invite_list).
with (await self.room_member_handler.member_linearizer.queue((room_id,))):
async with self.room_member_handler.member_linearizer.queue((room_id,)):
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
@ -1452,8 +1454,8 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def get_current_key(self) -> RoomStreamToken:
return self.store.get_room_max_token()
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
def get_current_key_for_room(self, room_id: str) -> Awaitable[RoomStreamToken]:
return self.store.get_current_room_stream_token_for_room_id(room_id)
class ShutdownRoomResponse(TypedDict):

View file

@ -158,8 +158,8 @@ class RoomBatchHandler:
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
them in a floating state event chain which don't resolve into the current room
state. They are floating because they reference no prev_events and are marked
as outliers which disconnects them from the normal DAG.
state. They are floating because they reference no prev_events which disconnects
them from the normal DAG.
Args:
state_events_at_start:
@ -215,31 +215,23 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
# The first event in the state chain is floating with no
# `prev_events` which means it can't derive state from
# anywhere automatically. So we need to set some state
# explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
# reference and also update in the event when we append
# later.
state_event_ids=state_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
# and will use this code path. Maybe we only want to support join state events
# and can get rid of this `else`?
(
event,
_,
@ -248,21 +240,15 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
# The first event in the state chain is floating with no
# `prev_events` which means it can't derive state from
# anywhere automatically. So we need to set some state
# explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same

View file

@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We first linearise by the application service (to try to limit concurrent joins
# by application services), and then by room ID.
with (await self.member_as_limiter.queue(as_id)):
with (await self.member_linearizer.queue(key)):
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
result = await self.update_membership_locked(
requester,
target,

View file

@ -59,8 +59,6 @@ class SearchHandler:
self.state_store = self.storage.state
self.auth = hs.get_auth()
self._msc3666_enabled = hs.config.experimental.msc3666_enabled
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
@ -353,22 +351,20 @@ class SearchHandler:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = None
if self._msc3666_enabled:
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
# The returned events.
search_result.allowed_events,
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
user.to_string(),
)
# The returned events.
search_result.allowed_events,
),
user.to_string(),
)
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.

View file

@ -430,7 +430,7 @@ class SsoHandler:
# grab a lock while we try to find a mapping for this user. This seems...
# optimistic, especially for implementations that end up redirecting to
# interstitial pages.
with await self._mapping_lock.queue(auth_provider_id):
async with self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id,

View file

@ -13,17 +13,7 @@
# limitations under the License.
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
FrozenSet,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
DeviceListUpdates,
JsonDict,
MutableStateMap,
Requester,
@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: List of user_ids whose devices may have changed
left: List of user_ids whose devices we no longer track
"""
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@ -240,7 +216,7 @@ class SyncResult:
knocked: List[KnockedSyncResult]
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
device_lists: DeviceLists
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
groups: Optional[GroupsSyncResult]
@ -298,6 +274,8 @@ class SyncHandler:
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user(
self,
requester: Requester,
@ -1176,8 +1154,9 @@ class SyncHandler:
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
if self.hs_config.experimental.groups_enabled:
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
num_events = 0
@ -1261,8 +1240,8 @@ class SyncHandler:
newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
"""Generate the DeviceLists section of sync
) -> DeviceListUpdates:
"""Generate the DeviceListUpdates section of sync
Args:
sync_result_builder
@ -1380,9 +1359,11 @@ class SyncHandler:
if any(e.room_id in joined_rooms for e in entries):
newly_left_users.discard(user_id)
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
return DeviceListUpdates(
changed=users_that_have_changed, left=newly_left_users
)
else:
return DeviceLists(changed=[], left=[])
return DeviceListUpdates()
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
@ -1606,13 +1587,15 @@ class SyncHandler:
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
room_changes = await self._get_all_rooms(
sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@ -1688,7 +1671,10 @@ class SyncHandler:
return False
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
excluded_rooms: List[str],
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
@ -1720,7 +1706,7 @@ class SyncHandler:
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
user_id, since_token.room_key, now_token.room_key, excluded_rooms
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@ -1864,6 +1850,7 @@ class SyncHandler:
full_state=False,
since_token=since_token,
upto_token=leave_token,
out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
)
)
@ -1921,7 +1908,10 @@ class SyncHandler:
)
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
ignored_rooms: List[str],
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@ -1932,7 +1922,7 @@ class SyncHandler:
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
ignored_rooms: List of rooms to ignore.
"""
user_id = sync_result_builder.sync_config.user.to_string()
@ -1943,6 +1933,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
membership_list=Membership.LIST,
excluded_rooms=ignored_rooms,
)
room_entries = []
@ -2126,33 +2117,41 @@ class SyncHandler:
):
return
state = await self.compute_state_delta(
room_id,
batch,
sync_config,
since_token,
now_token,
full_state=full_state,
)
if not room_builder.out_of_band:
state = await self.compute_state_delta(
room_id,
batch,
sync_config,
since_token,
now_token,
full_state=full_state,
)
else:
# An out of band room won't have any state changes.
state = {}
summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
if sync_config.filter_collection.lazy_load_members() and (
# we recalculate the summary:
# if there are membership changes in the timeline, or
# if membership has changed during a gappy sync, or
# if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events)
or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited
and any(t == EventTypes.Member for (t, k) in state)
if (
not room_builder.out_of_band
and sync_config.filter_collection.lazy_load_members()
and (
# we recalculate the summary:
# if there are membership changes in the timeline, or
# if membership has changed during a gappy sync, or
# if this is an initial sync.
any(ev.type == EventTypes.Member for ev in batch.events)
or (
# XXX: this may include false positives in the form of LL
# members which have snuck into state
batch.limited
and any(t == EventTypes.Member for (t, k) in state)
)
or since_token is None
)
or since_token is None
):
summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token
@ -2396,6 +2395,8 @@ class RoomSyncResultBuilder:
full_state: Whether the full state should be sent in result
since_token: Earliest point to return events from, or None
upto_token: Latest point to return events from.
out_of_band: whether the events in the room are "out of band" events
and the server isn't in the room.
"""
room_id: str
@ -2405,3 +2406,5 @@ class RoomSyncResultBuilder:
full_state: bool
since_token: Optional[StreamToken]
upto_token: StreamToken
out_of_band: bool = False

View file

@ -107,6 +107,8 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
assert self._secret is not None
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={

View file

@ -22,7 +22,6 @@ from typing import (
BinaryIO,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
@ -72,6 +71,7 @@ from twisted.web.iweb import (
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
@ -97,10 +97,6 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the entries can either be Lists or bytes.
RawHeaderValue = Sequence[Union[str, bytes]]
# the type of the query params, to be passed into `urlencode`
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
@ -912,7 +908,7 @@ def read_body_with_max_size(
return d
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
@ -925,13 +921,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
if args is None:
return b""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]
query_str = urllib.parse.urlencode(encoded_args, True)
query_str = urllib.parse.urlencode(args, True)
return query_str.encode("utf8")

View file

@ -67,6 +67,7 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
@ -98,10 +99,6 @@ MAXINT = sys.maxsize
_next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]]
T = TypeVar("T")
@ -144,7 +141,7 @@ class MatrixFederationRequest:
"""A callback to generate the JSON.
"""
query: Optional[dict] = None
query: Optional[QueryParams] = None
"""Query arguments.
"""
@ -165,10 +162,7 @@ class MatrixFederationRequest:
destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
if self.query:
query_bytes = encode_query_args(self.query)
else:
query_bytes = b""
query_bytes = encode_query_args(self.query)
# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
@ -485,10 +479,7 @@ class MatrixFederationHttpClient:
method_bytes = request.method.encode("ascii")
destination_bytes = request.destination.encode("ascii")
path_bytes = request.path.encode("ascii")
if request.query:
query_bytes = encode_query_args(request.query)
else:
query_bytes = b""
query_bytes = encode_query_args(request.query)
scope = start_active_span(
"outgoing-federation-request",
@ -746,7 +737,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@ -764,7 +755,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@ -781,7 +772,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
@ -891,7 +882,7 @@ class MatrixFederationHttpClient:
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Sends the specified json data using POST
@ -961,7 +952,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
@ -976,7 +967,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = ...,
args: Optional[QueryParams] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
@ -990,7 +981,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
@ -1085,7 +1076,7 @@ class MatrixFederationHttpClient:
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
@ -1150,7 +1141,7 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,

21
synapse/http/types.py Normal file
View file

@ -0,0 +1,21 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterable, Mapping, Union
# the type of the query params, to be passed into `urlencode` with `doseq=True`.
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
__all__ = ["QueryParams"]

View file

@ -289,6 +289,9 @@ class SynapseTags:
# Uniqueish ID of a database transaction
DB_TXN_ID = "db.txn_id"
# The name of the external cache
CACHE_NAME = "cache.name"
class SynapseBaggage:
FORCE_TRACING = "synapse-force-tracing"

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -52,12 +53,13 @@ from synapse.metrics._exposition import (
start_http_server,
)
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._types import Collector
logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
all_gauges: Dict[str, Collector] = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@ -78,11 +80,10 @@ RegistryProxy = cast(CollectorRegistry, _RegistryProxy)
@attr.s(slots=True, hash=True, auto_attribs=True)
class LaterGauge:
class LaterGauge(Collector):
name: str
desc: str
labels: Optional[Iterable[str]] = attr.ib(hash=False)
labels: Optional[Sequence[str]] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@ -125,7 +126,7 @@ class LaterGauge:
MetricsEntry = TypeVar("MetricsEntry")
class InFlightGauge(Generic[MetricsEntry]):
class InFlightGauge(Generic[MetricsEntry], Collector):
"""Tracks number of things (e.g. requests, Measure blocks, etc) in flight
at any given time.
@ -246,7 +247,7 @@ class InFlightGauge(Generic[MetricsEntry]):
all_gauges[self.name] = self
class GaugeBucketCollector:
class GaugeBucketCollector(Collector):
"""Like a Histogram, but the buckets are Gauges which are updated atomically.
The data is updated by calling `update_data` with an iterable of measurements.
@ -340,7 +341,7 @@ class GaugeBucketCollector:
#
class CPUMetrics:
class CPUMetrics(Collector):
def __init__(self) -> None:
ticks_per_sec = 100
try:
@ -470,6 +471,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
__all__ = [
"Collector",
"MetricsResource",
"generate_latest",
"start_http_server",

View file

@ -30,6 +30,8 @@ from prometheus_client.core import (
from twisted.internet import task
from synapse.metrics._types import Collector
"""Prometheus metrics for garbage collection"""
@ -71,7 +73,7 @@ gc_time = Histogram(
)
class GCCounts:
class GCCounts(Collector):
def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
@ -135,7 +137,7 @@ def install_gc_manager() -> None:
#
class PyPyGCStats:
class PyPyGCStats(Collector):
def collect(self) -> Iterable[Metric]:
# @stats is a pretty-printer object with __str__() returning a nice table,

View file

@ -21,6 +21,8 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily
from twisted.internet import reactor
from synapse.metrics._types import Collector
#
# Twisted reactor metrics
#
@ -54,7 +56,7 @@ class EpollWrapper:
return getattr(self._poller, item)
class ReactorLastSeenMetric:
class ReactorLastSeenMetric(Collector):
def __init__(self, epoll_wrapper: EpollWrapper):
self._epoll_wrapper = epoll_wrapper

31
synapse/metrics/_types.py Normal file
View file

@ -0,0 +1,31 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Iterable
from prometheus_client import Metric
try:
from prometheus_client.registry import Collector
except ImportError:
# prometheus_client.Collector is new as of prometheus 0.14. We redefine it here
# for compatibility with earlier versions.
class _Collector(ABC):
@abstractmethod
def collect(self) -> Iterable[Metric]:
pass
Collector = _Collector # type: ignore

View file

@ -46,6 +46,7 @@ from synapse.logging.opentracing import (
noop_context_manager,
start_active_span,
)
from synapse.metrics._types import Collector
if TYPE_CHECKING:
import resource
@ -127,7 +128,7 @@ _background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set(
_bg_metrics_lock = threading.Lock()
class _Collector:
class _Collector(Collector):
"""A custom metrics collector for the background process metrics.
Ensures that all of the metrics are up-to-date with any in-flight processes

View file

@ -16,11 +16,13 @@ import ctypes
import logging
import os
import re
from typing import Iterable, Optional
from typing import Iterable, Optional, overload
from prometheus_client import Metric
from prometheus_client import REGISTRY, Metric
from typing_extensions import Literal
from synapse.metrics import REGISTRY, GaugeMetricFamily
from synapse.metrics import GaugeMetricFamily
from synapse.metrics._types import Collector
logger = logging.getLogger(__name__)
@ -59,6 +61,16 @@ def _setup_jemalloc_stats() -> None:
jemalloc = ctypes.CDLL(jemalloc_path)
@overload
def _mallctl(
name: str, read: Literal[True] = True, write: Optional[int] = None
) -> int:
...
@overload
def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
...
def _mallctl(
name: str, read: bool = True, write: Optional[int] = None
) -> Optional[int]:
@ -134,7 +146,7 @@ def _setup_jemalloc_stats() -> None:
except Exception as e:
logger.warning("Failed to reload jemalloc stats: %s", e)
class JemallocCollector:
class JemallocCollector(Collector):
"""Metrics for internal jemalloc stats."""
def collect(self) -> Iterable[Metric]:

Some files were not shown because too many files have changed in this diff Show more