"]
license = "Apache-2.0"
@@ -158,6 +158,9 @@ packaging = ">=16.1"
# At the time of writing, we only use functions from the version `importlib.metadata`
# which shipped in Python 3.8. This corresponds to version 1.4 of the backport.
importlib_metadata = { version = ">=1.4", python = "<3.8" }
+# This is the most recent version of Pydantic with available on common distros.
+pydantic = ">=1.7.4"
+
# Optional Dependencies
diff --git a/requirements.txt b/requirements.txt
index 4c87bd4bb..7a6c0398e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -497,6 +497,42 @@ pyasn1==0.4.8 \
pycparser==2.21 ; python_full_version >= "3.6.7" and platform_python_implementation == "PyPy" \
--hash=sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9 \
--hash=sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206
+pydantic==1.9.1 \
+ --hash=sha256:c8098a724c2784bf03e8070993f6d46aa2eeca031f8d8a048dff277703e6e193 \
+ --hash=sha256:c320c64dd876e45254bdd350f0179da737463eea41c43bacbee9d8c9d1021f11 \
+ --hash=sha256:18f3e912f9ad1bdec27fb06b8198a2ccc32f201e24174cec1b3424dda605a310 \
+ --hash=sha256:c11951b404e08b01b151222a1cb1a9f0a860a8153ce8334149ab9199cd198131 \
+ --hash=sha256:8bc541a405423ce0e51c19f637050acdbdf8feca34150e0d17f675e72d119580 \
+ --hash=sha256:e565a785233c2d03724c4dc55464559639b1ba9ecf091288dd47ad9c629433bd \
+ --hash=sha256:a4a88dcd6ff8fd47c18b3a3709a89adb39a6373f4482e04c1b765045c7e282fd \
+ --hash=sha256:447d5521575f18e18240906beadc58551e97ec98142266e521c34968c76c8761 \
+ --hash=sha256:985ceb5d0a86fcaa61e45781e567a59baa0da292d5ed2e490d612d0de5796918 \
+ --hash=sha256:059b6c1795170809103a1538255883e1983e5b831faea6558ef873d4955b4a74 \
+ --hash=sha256:d12f96b5b64bec3f43c8e82b4aab7599d0157f11c798c9f9c528a72b9e0b339a \
+ --hash=sha256:ae72f8098acb368d877b210ebe02ba12585e77bd0db78ac04a1ee9b9f5dd2166 \
+ --hash=sha256:79b485767c13788ee314669008d01f9ef3bc05db9ea3298f6a50d3ef596a154b \
+ --hash=sha256:494f7c8537f0c02b740c229af4cb47c0d39840b829ecdcfc93d91dcbb0779892 \
+ --hash=sha256:f0f047e11febe5c3198ed346b507e1d010330d56ad615a7e0a89fae604065a0e \
+ --hash=sha256:969dd06110cb780da01336b281f53e2e7eb3a482831df441fb65dd30403f4608 \
+ --hash=sha256:177071dfc0df6248fd22b43036f936cfe2508077a72af0933d0c1fa269b18537 \
+ --hash=sha256:9bcf8b6e011be08fb729d110f3e22e654a50f8a826b0575c7196616780683380 \
+ --hash=sha256:a955260d47f03df08acf45689bd163ed9df82c0e0124beb4251b1290fa7ae728 \
+ --hash=sha256:9ce157d979f742a915b75f792dbd6aa63b8eccaf46a1005ba03aa8a986bde34a \
+ --hash=sha256:0bf07cab5b279859c253d26a9194a8906e6f4a210063b84b433cf90a569de0c1 \
+ --hash=sha256:5d93d4e95eacd313d2c765ebe40d49ca9dd2ed90e5b37d0d421c597af830c195 \
+ --hash=sha256:1542636a39c4892c4f4fa6270696902acb186a9aaeac6f6cf92ce6ae2e88564b \
+ --hash=sha256:a9af62e9b5b9bc67b2a195ebc2c2662fdf498a822d62f902bf27cccb52dbbf49 \
+ --hash=sha256:fe4670cb32ea98ffbf5a1262f14c3e102cccd92b1869df3bb09538158ba90fe6 \
+ --hash=sha256:9f659a5ee95c8baa2436d392267988fd0f43eb774e5eb8739252e5a7e9cf07e0 \
+ --hash=sha256:b83ba3825bc91dfa989d4eed76865e71aea3a6ca1388b59fc801ee04c4d8d0d6 \
+ --hash=sha256:1dd8fecbad028cd89d04a46688d2fcc14423e8a196d5b0a5c65105664901f810 \
+ --hash=sha256:02eefd7087268b711a3ff4db528e9916ac9aa18616da7bca69c1871d0b7a091f \
+ --hash=sha256:7eb57ba90929bac0b6cc2af2373893d80ac559adda6933e562dcfb375029acee \
+ --hash=sha256:4ce9ae9e91f46c344bec3b03d6ee9612802682c1551aaf627ad24045ce090761 \
+ --hash=sha256:72ccb318bf0c9ab97fc04c10c37683d9eea952ed526707fabf9ac5ae59b701fd \
+ --hash=sha256:61b6760b08b7c395975d893e0b814a11cf011ebb24f7d869e7118f5a339a82e1 \
+ --hash=sha256:4988c0f13c42bfa9ddd2fe2f569c9d54646ce84adc5de84228cfe83396f3bd58 \
+ --hash=sha256:1ed987c3ff29fff7fd8c3ea3a3ea877ad310aae2ef9889a119e22d3f2db0691a
pymacaroons==0.13.0 \
--hash=sha256:3e14dff6a262fdbf1a15e769ce635a8aea72e6f8f91e408f9a97166c53b91907 \
--hash=sha256:1e6bba42a5f66c245adf38a5a4006a99dcc06a0703786ea636098667d42903b8
diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py
new file mode 100755
index 000000000..d0fb811bd
--- /dev/null
+++ b/scripts-dev/check_pydantic_models.py
@@ -0,0 +1,425 @@
+#! /usr/bin/env python
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A script which enforces that Synapse always uses strict types when defining a Pydantic
+model.
+
+Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See
+
+ https://github.com/pydantic/pydantic/issues/1098
+ https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode
+
+until then, this script is a best effort to stop us from introducing type coersion bugs
+(like the infamous stringy power levels fixed in room version 10).
+"""
+import argparse
+import contextlib
+import functools
+import importlib
+import logging
+import os
+import pkgutil
+import sys
+import textwrap
+import traceback
+import unittest.mock
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar
+
+from parameterized import parameterized
+from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
+from pydantic.typing import get_args
+from typing_extensions import ParamSpec
+
+logger = logging.getLogger(__name__)
+
+CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
+ constr,
+ conbytes,
+ conint,
+ confloat,
+]
+
+TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [
+ str,
+ bytes,
+ int,
+ float,
+ bool,
+]
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class ModelCheckerException(Exception):
+ """Dummy exception. Allows us to detect unwanted types during a module import."""
+
+
+class MissingStrictInConstrainedTypeException(ModelCheckerException):
+ factory_name: str
+
+ def __init__(self, factory_name: str):
+ self.factory_name = factory_name
+
+
+class FieldHasUnwantedTypeException(ModelCheckerException):
+ message: str
+
+ def __init__(self, message: str):
+ self.message = message
+
+
+def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
+ """We patch `constr` and friends with wrappers that enforce strict=True."""
+
+ @functools.wraps(factory)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668
+ if "strict" not in kwargs: # type: ignore[attr-defined]
+ raise MissingStrictInConstrainedTypeException(factory.__name__)
+ if not kwargs["strict"]: # type: ignore[index]
+ raise MissingStrictInConstrainedTypeException(factory.__name__)
+ return factory(*args, **kwargs)
+
+ return wrapper
+
+
+def field_type_unwanted(type_: Any) -> bool:
+ """Very rough attempt to detect if a type is unwanted as a Pydantic annotation.
+
+ At present, we exclude types which will coerce, or any generic type involving types
+ which will coerce."""
+ logger.debug("Is %s unwanted?")
+ if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
+ logger.debug("yes")
+ return True
+ logger.debug("Maybe. Subargs are %s", get_args(type_))
+ rv = any(field_type_unwanted(t) for t in get_args(type_))
+ logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
+ return rv
+
+
+class PatchedBaseModel(PydanticBaseModel):
+ """A patched version of BaseModel that inspects fields after models are defined.
+
+ We complain loudly if we see an unwanted type.
+
+ Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
+ """
+
+ @classmethod
+ def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
+ for field in cls.__fields__.values():
+ # Note that field.type_ and field.outer_type are computed based on the
+ # annotation type, see pydantic.fields.ModelField._type_analysis
+ if field_type_unwanted(field.outer_type_):
+ # TODO: this only reports the first bad field. Can we find all bad ones
+ # and report them all?
+ raise FieldHasUnwantedTypeException(
+ f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
+ f"with unwanted type `{field.outer_type_}`"
+ )
+
+
+@contextmanager
+def monkeypatch_pydantic() -> Generator[None, None, None]:
+ """Patch pydantic with our snooping versions of BaseModel and the con* functions.
+
+ If the snooping functions see something they don't like, they'll raise a
+ ModelCheckingException instance.
+ """
+ with contextlib.ExitStack() as patches:
+ # Most Synapse code ought to import the patched objects directly from
+ # `pydantic`. But we also patch their containing modules `pydantic.main` and
+ # `pydantic.types` for completeness.
+ patch_basemodel1 = unittest.mock.patch(
+ "pydantic.BaseModel", new=PatchedBaseModel
+ )
+ patch_basemodel2 = unittest.mock.patch(
+ "pydantic.main.BaseModel", new=PatchedBaseModel
+ )
+ patches.enter_context(patch_basemodel1)
+ patches.enter_context(patch_basemodel2)
+ for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
+ wrapper: Callable = make_wrapper(factory)
+ patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper)
+ patch2 = unittest.mock.patch(
+ f"pydantic.types.{factory.__name__}", new=wrapper
+ )
+ patches.enter_context(patch1)
+ patches.enter_context(patch2)
+ yield
+
+
+def format_model_checker_exception(e: ModelCheckerException) -> str:
+ """Work out which line of code caused e. Format the line in a human-friendly way."""
+ # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
+ # patches of constr() etc, and instead inspect fields to look for ConstrainedStr
+ # with strict=False? There is some difficulty with the inheritance hierarchy
+ # because StrictStr < ConstrainedStr < str.
+ if isinstance(e, FieldHasUnwantedTypeException):
+ return e.message
+ elif isinstance(e, MissingStrictInConstrainedTypeException):
+ frame_summary = traceback.extract_tb(e.__traceback__)[-2]
+ return (
+ f"Missing `strict=True` from {e.factory_name}() call \n"
+ + traceback.format_list([frame_summary])[0].lstrip()
+ )
+ else:
+ raise ValueError(f"Unknown exception {e}") from e
+
+
+def lint() -> int:
+ """Try to import all of Synapse and see if we spot any Pydantic type coercions.
+
+ Print any problems, then return a status code suitable for sys.exit."""
+ failures = do_lint()
+ if failures:
+ print(f"Found {len(failures)} problem(s)")
+ for failure in sorted(failures):
+ print(failure)
+ return os.EX_DATAERR if failures else os.EX_OK
+
+
+def do_lint() -> Set[str]:
+ """Try to import all of Synapse and see if we spot any Pydantic type coercions."""
+ failures = set()
+
+ with monkeypatch_pydantic():
+ logger.debug("Importing synapse")
+ try:
+ # TODO: make "synapse" an argument so we can target this script at
+ # a subpackage
+ module = importlib.import_module("synapse")
+ except ModelCheckerException as e:
+ logger.warning("Bad annotation found when importing synapse")
+ failures.add(format_model_checker_exception(e))
+ return failures
+
+ try:
+ logger.debug("Fetching subpackages")
+ module_infos = list(
+ pkgutil.walk_packages(module.__path__, f"{module.__name__}.")
+ )
+ except ModelCheckerException as e:
+ logger.warning("Bad annotation found when looking for modules to import")
+ failures.add(format_model_checker_exception(e))
+ return failures
+
+ for module_info in module_infos:
+ logger.debug("Importing %s", module_info.name)
+ try:
+ importlib.import_module(module_info.name)
+ except ModelCheckerException as e:
+ logger.warning(
+ f"Bad annotation found when importing {module_info.name}"
+ )
+ failures.add(format_model_checker_exception(e))
+
+ return failures
+
+
+def run_test_snippet(source: str) -> None:
+ """Exec a snippet of source code in an isolated environment."""
+ # To emulate `source` being called at the top level of the module,
+ # the globals and locals we provide apparently have to be the same mapping.
+ #
+ # > Remember that at the module level, globals and locals are the same dictionary.
+ # > If exec gets two separate objects as globals and locals, the code will be
+ # > executed as if it were embedded in a class definition.
+ globals_: Dict[str, object]
+ locals_: Dict[str, object]
+ globals_ = locals_ = {}
+ exec(textwrap.dedent(source), globals_, locals_)
+
+
+class TestConstrainedTypesPatch(unittest.TestCase):
+ def test_expression_without_strict_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import constr
+ constr()
+ """
+ )
+
+ def test_called_as_module_attribute_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ import pydantic
+ pydantic.constr()
+ """
+ )
+
+ def test_wildcard_import_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import *
+ constr()
+ """
+ )
+
+ def test_alternative_import_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic.types import constr
+ constr()
+ """
+ )
+
+ def test_alternative_import_attribute_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ import pydantic.types
+ pydantic.types.constr()
+ """
+ )
+
+ def test_kwarg_but_no_strict_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import constr
+ constr(min_length=10)
+ """
+ )
+
+ def test_kwarg_strict_False_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import constr
+ constr(strict=False)
+ """
+ )
+
+ def test_kwarg_strict_True_doesnt_raise(self) -> None:
+ with monkeypatch_pydantic():
+ run_test_snippet(
+ """
+ from pydantic import constr
+ constr(strict=True)
+ """
+ )
+
+ def test_annotation_without_strict_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import constr
+ x: constr()
+ """
+ )
+
+ def test_field_annotation_without_strict_raises(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic import BaseModel, conint
+ class C:
+ x: conint()
+ """
+ )
+
+
+class TestFieldTypeInspection(unittest.TestCase):
+ @parameterized.expand(
+ [
+ ("str",),
+ ("bytes"),
+ ("int",),
+ ("float",),
+ ("bool"),
+ ("Optional[str]",),
+ ("Union[None, str]",),
+ ("List[str]",),
+ ("List[List[str]]",),
+ ("Dict[StrictStr, str]",),
+ ("Dict[str, StrictStr]",),
+ ("TypedDict('D', x=int)",),
+ ]
+ )
+ def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ f"""
+ from typing import *
+ from pydantic import *
+ class C(BaseModel):
+ f: {annotation}
+ """
+ )
+
+ @parameterized.expand(
+ [
+ ("StrictStr",),
+ ("StrictBytes"),
+ ("StrictInt",),
+ ("StrictFloat",),
+ ("StrictBool"),
+ ("constr(strict=True, min_length=10)",),
+ ("Optional[StrictStr]",),
+ ("Union[None, StrictStr]",),
+ ("List[StrictStr]",),
+ ("List[List[StrictStr]]",),
+ ("Dict[StrictStr, StrictStr]",),
+ ("TypedDict('D', x=StrictInt)",),
+ ]
+ )
+ def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None:
+ with monkeypatch_pydantic():
+ run_test_snippet(
+ f"""
+ from typing import *
+ from pydantic import *
+ class C(BaseModel):
+ f: {annotation}
+ """
+ )
+
+ def test_field_holding_str_raises_with_alternative_import(self) -> None:
+ with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
+ run_test_snippet(
+ """
+ from pydantic.main import BaseModel
+ class C(BaseModel):
+ f: str
+ """
+ )
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?")
+parser.add_argument("-v", "--verbose", action="store_true")
+
+
+if __name__ == "__main__":
+ args = parser.parse_args(sys.argv[1:])
+ logging.basicConfig(
+ format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s",
+ level=logging.DEBUG if args.verbose else logging.INFO,
+ )
+ # suppress logs we don't care about
+ logging.getLogger("xmlschema").setLevel(logging.WARNING)
+ if args.mode == "lint":
+ sys.exit(lint())
+ elif args.mode == "test":
+ unittest.main(argv=sys.argv[:1])
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 377348b10..bf900645b 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -106,4 +106,5 @@ isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 "${files[@]}"
+./scripts-dev/check_pydantic_models.py lint
mypy
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6ae45ac1f..d11e55467 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -37,8 +37,7 @@ from synapse.logging.opentracing import (
start_active_span,
trace,
)
-from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import Requester, create_requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -70,14 +69,14 @@ class Auth:
async def check_user_in_room(
self,
room_id: str,
- user_id: str,
+ requester: Requester,
allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
- user_id: The user to check.
+ requester: The user making the request, according to the access token.
current_state: Optional map of the current state of the room.
If provided then that map is used to check whether they are a
@@ -94,6 +93,7 @@ class Auth:
membership event ID of the user.
"""
+ user_id = requester.user.to_string()
(
membership,
member_event_id,
@@ -182,96 +182,69 @@ class Auth:
access_token = self.get_access_token_from_request(request)
- (
- user_id,
- device_id,
- app_service,
- ) = await self._get_appservice_user_id_and_device_id(request)
- if user_id and app_service:
- if ip_addr and self._track_appservice_user_ips:
- await self.store.insert_client_ip(
- user_id=user_id,
- access_token=access_token,
- ip=ip_addr,
- user_agent=user_agent,
- device_id="dummy-device"
- if device_id is None
- else device_id, # stubbed
- )
-
- requester = create_requester(
- user_id, app_service=app_service, device_id=device_id
+ # First check if it could be a request from an appservice
+ requester = await self._get_appservice_user(request)
+ if not requester:
+ # If not, it should be from a regular user
+ requester = await self.get_user_by_access_token(
+ access_token, allow_expired=allow_expired
)
- request.requester = user_id
- return requester
+ # Deny the request if the user account has expired.
+ # This check is only done for regular users, not appservice ones.
+ if not allow_expired:
+ if await self._account_validity_handler.is_user_expired(
+ requester.user.to_string()
+ ):
+ # Raise the error if either an account validity module has determined
+ # the account has expired, or the legacy account validity
+ # implementation is enabled and determined the account has expired
+ raise AuthError(
+ 403,
+ "User account has expired",
+ errcode=Codes.EXPIRED_ACCOUNT,
+ )
- user_info = await self.get_user_by_access_token(
- access_token, allow_expired=allow_expired
- )
- token_id = user_info.token_id
- is_guest = user_info.is_guest
- shadow_banned = user_info.shadow_banned
-
- # Deny the request if the user account has expired.
- if not allow_expired:
- if await self._account_validity_handler.is_user_expired(
- user_info.user_id
- ):
- # Raise the error if either an account validity module has determined
- # the account has expired, or the legacy account validity
- # implementation is enabled and determined the account has expired
- raise AuthError(
- 403,
- "User account has expired",
- errcode=Codes.EXPIRED_ACCOUNT,
- )
-
- device_id = user_info.device_id
-
- if access_token and ip_addr:
+ if ip_addr and (
+ not requester.app_service or self._track_appservice_user_ips
+ ):
+ # XXX(quenting): I'm 95% confident that we could skip setting the
+ # device_id to "dummy-device" for appservices, and that the only impact
+ # would be some rows which whould not deduplicate in the 'user_ips'
+ # table during the transition
+ recorded_device_id = (
+ "dummy-device"
+ if requester.device_id is None and requester.app_service is not None
+ else requester.device_id
+ )
await self.store.insert_client_ip(
- user_id=user_info.token_owner,
+ user_id=requester.authenticated_entity,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
- device_id=device_id,
+ device_id=recorded_device_id,
)
+
# Track also the puppeted user client IP if enabled and the user is puppeting
if (
- user_info.user_id != user_info.token_owner
+ requester.user.to_string() != requester.authenticated_entity
and self._track_puppeted_user_ips
):
await self.store.insert_client_ip(
- user_id=user_info.user_id,
+ user_id=requester.user.to_string(),
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
- device_id=device_id,
+ device_id=requester.device_id,
)
- if is_guest and not allow_guest:
+ if requester.is_guest and not allow_guest:
raise AuthError(
403,
"Guest access not allowed",
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- # Mark the token as used. This is used to invalidate old refresh
- # tokens after some time.
- if not user_info.token_used and token_id is not None:
- await self.store.mark_access_token_as_used(token_id)
-
- requester = create_requester(
- user_info.user_id,
- token_id,
- is_guest,
- shadow_banned,
- device_id,
- app_service=app_service,
- authenticated_entity=user_info.token_owner,
- )
-
request.requester = requester
return requester
except KeyError:
@@ -309,9 +282,7 @@ class Auth:
403, "Application service has not registered this user (%s)" % user_id
)
- async def _get_appservice_user_id_and_device_id(
- self, request: Request
- ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+ async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
"""
Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request
@@ -326,15 +297,13 @@ class Auth:
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
Returns:
- 3-tuple of
- (user ID?, device ID?, application service?)
+ the application service `Requester` of that request
Postconditions:
- - If an application service is returned, so is a user ID
- - A user ID is never returned without an application service
- - A device ID is never returned without a user ID or an application service
- - The returned application service, if present, is permitted to control the
- returned user ID.
+ - The `app_service` field in the returned `Requester` is set
+ - The `user_id` field in the returned `Requester` is either the application
+ service sender or the controlled user set by the `user_id` URI parameter
+ - The returned application service is permitted to control the returned user ID.
- The returned device ID, if present, has been checked to be a valid device ID
for the returned user ID.
"""
@@ -344,12 +313,12 @@ class Auth:
self.get_access_token_from_request(request)
)
if app_service is None:
- return None, None, None
+ return None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist:
- return None, None, None
+ return None
# This will always be set by the time Twisted calls us.
assert request.args is not None
@@ -383,13 +352,15 @@ class Auth:
Codes.EXCLUSIVE,
)
- return effective_user_id, effective_device_id, app_service
+ return create_requester(
+ effective_user_id, app_service=app_service, device_id=effective_device_id
+ )
async def get_user_by_access_token(
self,
token: str,
allow_expired: bool = False,
- ) -> TokenLookupResult:
+ ) -> Requester:
"""Validate access token and get user_id from it
Args:
@@ -406,9 +377,9 @@ class Auth:
# First look in the database to see if the access token is present
# as an opaque token.
- r = await self.store.get_user_by_access_token(token)
- if r:
- valid_until_ms = r.valid_until_ms
+ user_info = await self.store.get_user_by_access_token(token)
+ if user_info:
+ valid_until_ms = user_info.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@@ -420,7 +391,20 @@ class Auth:
msg="Access token has expired", soft_logout=True
)
- return r
+ # Mark the token as used. This is used to invalidate old refresh
+ # tokens after some time.
+ await self.store.mark_access_token_as_used(user_info.token_id)
+
+ requester = create_requester(
+ user_id=user_info.user_id,
+ access_token_id=user_info.token_id,
+ is_guest=user_info.is_guest,
+ shadow_banned=user_info.shadow_banned,
+ device_id=user_info.device_id,
+ authenticated_entity=user_info.token_owner,
+ )
+
+ return requester
# If the token isn't found in the database, then it could still be a
# macaroon for a guest, so we check that here.
@@ -446,11 +430,12 @@ class Auth:
"Guest access token used for regular user"
)
- return TokenLookupResult(
+ return create_requester(
user_id=user_id,
is_guest=True,
# all guests get the same device id
device_id=GUEST_DEVICE_ID,
+ authenticated_entity=user_id,
)
except (
pymacaroons.exceptions.MacaroonException,
@@ -473,32 +458,33 @@ class Auth:
request.requester = create_requester(service.sender, app_service=service)
return service
- async def is_server_admin(self, user: UserID) -> bool:
+ async def is_server_admin(self, requester: Requester) -> bool:
"""Check if the given user is a local server admin.
Args:
- user: user to check
+ requester: The user making the request, according to the access token.
Returns:
True if the user is an admin
"""
- return await self.store.is_server_admin(user)
+ return await self.store.is_server_admin(requester.user)
- async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
+ async def check_can_change_room_list(
+ self, room_id: str, requester: Requester
+ ) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
Args:
- room_id
- user
+ room_id: The room to check.
+ requester: The user making the request, according to the access token.
"""
- is_admin = await self.is_server_admin(user)
+ is_admin = await self.is_server_admin(requester)
if is_admin:
return True
- user_id = user.to_string()
- await self.check_user_in_room(room_id, user_id)
+ await self.check_user_in_room(room_id, requester)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
@@ -517,7 +503,9 @@ class Auth:
send_level = event_auth.get_send_level(
EventTypes.CanonicalAlias, "", power_level_event
)
- user_level = event_auth.get_user_power_level(user_id, auth_events)
+ user_level = event_auth.get_user_power_level(
+ requester.user.to_string(), auth_events
+ )
return user_level >= send_level
@@ -575,16 +563,16 @@ class Auth:
@trace
async def check_user_in_room_or_world_readable(
- self, room_id: str, user_id: str, allow_departed_users: bool = False
+ self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
Args:
- room_id: room to check
- user_id: user to check
- allow_departed_users: if True, accept users that were previously
- members but have now departed
+ room_id: The room to check.
+ requester: The user making the request, according to the access token.
+ allow_departed_users: If True, accept users that were previously
+ members but have now departed.
Returns:
Resolves to the current membership of the user in the room and the
@@ -599,7 +587,7 @@ class Auth:
# * The user is a guest user, and has joined the room
# else it will throw.
return await self.check_user_in_room(
- room_id, user_id, allow_departed_users=allow_departed_users
+ room_id, requester, allow_departed_users=allow_departed_users
)
except AuthError:
visibility = await self._storage_controllers.state.get_current_state_event(
@@ -614,6 +602,6 @@ class Auth:
raise UnstableSpecAuthError(
403,
"User %s not in room %s, and room previews are disabled"
- % (user_id, room_id),
+ % (requester.user, room_id),
errcode=Codes.NOT_JOINED,
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 1d46fb0e4..c73aea622 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -216,11 +216,11 @@ class EventContentFields:
MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
- MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
+ MSC2716_NEXT_BATCH_ID: Final = "next_batch_id"
# Used on "batch" events to indicate which insertion event it connects to
- MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
+ MSC2716_BATCH_ID: Final = "batch_id"
# For "marker" events
- MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
+ MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference"
# The authorising user for joining a restricted room.
AUTHORISING_USER: Final = "join_authorised_via_users_server"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 00e81b3af..a0e4ab6db 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -269,24 +269,6 @@ class RoomVersions:
msc3787_knock_restricted_join_rule=False,
msc3667_int_only_power_levels=False,
)
- MSC2716v3 = RoomVersion(
- "org.matrix.msc2716v3",
- RoomDisposition.UNSTABLE,
- EventFormatVersions.V3,
- StateResolutionVersions.V2,
- enforce_key_validity=True,
- special_case_aliases_auth=False,
- strict_canonicaljson=True,
- limit_notifications_power_levels=True,
- msc2176_redaction_rules=False,
- msc3083_join_rules=False,
- msc3375_redaction_rules=False,
- msc2403_knocking=True,
- msc2716_historical=True,
- msc2716_redactions=True,
- msc3787_knock_restricted_join_rule=False,
- msc3667_int_only_power_levels=False,
- )
MSC3787 = RoomVersion(
"org.matrix.msc3787",
RoomDisposition.UNSTABLE,
@@ -323,6 +305,24 @@ class RoomVersions:
msc3787_knock_restricted_join_rule=True,
msc3667_int_only_power_levels=True,
)
+ MSC2716v4 = RoomVersion(
+ "org.matrix.msc2716v4",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc3375_redaction_rules=False,
+ msc2403_knocking=True,
+ msc2716_historical=True,
+ msc2716_redactions=True,
+ msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
+ )
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
@@ -338,9 +338,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V7,
RoomVersions.V8,
RoomVersions.V9,
- RoomVersions.MSC2716v3,
RoomVersions.MSC3787,
RoomVersions.V10,
+ RoomVersions.MSC2716v4,
)
}
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 42d1f6d21..30e21d970 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -441,6 +441,13 @@ def start(config_options: List[str]) -> None:
"synapse.app.user_dir",
)
+ if config.experimental.faster_joins_enabled:
+ raise ConfigError(
+ "You have enabled the experimental `faster_joins` config option, but it is "
+ "not compatible with worker deployments yet. Please disable `faster_joins` "
+ "or run Synapse as a single process deployment instead."
+ )
+
synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 745e70414..d98012ade 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -220,7 +220,10 @@ class SynapseHomeServer(HomeServer):
resources.update({"/_matrix/consent": consent_resource})
if name == "federation":
- resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+ federation_resource: Resource = TransportLayerServer(self)
+ if compress:
+ federation_resource = gz_wrap(federation_resource)
+ resources.update({FEDERATION_PREFIX: federation_resource})
if name == "openid":
resources.update(
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index d1335e77c..b3972ede9 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -23,7 +23,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'account_validity' section. Support for this setting has been deprecated and will be
removed in a future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 7765c5b45..66a6dbf1f 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -53,7 +53,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'email' section. Support for this setting has been deprecated and will be removed in a
future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7d17c958b..c1ff41753 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -90,3 +90,6 @@ class ExperimentalConfig(Config):
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+ # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
+ self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 2178cbf98..a452cc3a4 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -26,7 +26,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
This server's configuration file is using the deprecated 'template_dir' setting in the
'sso' section. Support for this setting has been deprecated and will be removed in a
future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ac91c5eb5..71853caad 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -161,7 +161,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
- add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
+ add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 54ffbd817..987f6dad4 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,7 +61,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -235,6 +235,7 @@ class FederationClient(FederationBase):
)
@trace
+ @tag_args
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]:
@@ -337,6 +338,8 @@ class FederationClient(FederationBase):
return None
+ @trace
+ @tag_args
async def get_pdu(
self,
destinations: Iterable[str],
@@ -448,6 +451,8 @@ class FederationClient(FederationBase):
return event_copy
+ @trace
+ @tag_args
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> Tuple[List[str], List[str]]:
@@ -467,6 +472,23 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "state_event_ids",
+ str(state_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "state_event_ids.length",
+ str(len(state_event_ids)),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "auth_event_ids",
+ str(auth_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "auth_event_ids.length",
+ str(len(auth_event_ids)),
+ )
+
if not isinstance(state_event_ids, list) or not isinstance(
auth_event_ids, list
):
@@ -474,6 +496,8 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ @trace
+ @tag_args
async def get_room_state(
self,
destination: str,
@@ -533,6 +557,7 @@ class FederationClient(FederationBase):
return valid_state_events, valid_auth_events
+ @trace
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index db4b83a50..75fbc6073 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -61,7 +61,12 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
-from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.opentracing import (
+ log_kv,
+ start_active_span_from_edu,
+ tag_args,
+ trace,
+)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -547,6 +552,8 @@ class FederationServer(FederationBase):
return 200, resp
+ @trace
+ @tag_args
async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
@@ -569,6 +576,8 @@ class FederationServer(FederationBase):
return 200, resp
+ @trace
+ @tag_args
async def _on_state_ids_request_compute(
self, room_id: str, event_id: str
) -> JsonDict:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bfa553504..0327fc57a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -280,7 +280,7 @@ class AuthHandler:
that it isn't stolen by re-authenticating them.
Args:
- requester: The user, as given by the access token
+ requester: The user making the request, according to the access token.
request: The request sent by the client.
@@ -1435,20 +1435,25 @@ class AuthHandler:
access_token: access token to be deleted
"""
- user_info = await self.auth.get_user_by_access_token(access_token)
+ token = await self.store.get_user_by_access_token(access_token)
+ if not token:
+ # At this point, the token should already have been fetched once by
+ # the caller, so this should not happen, unless of a race condition
+ # between two delete requests
+ raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
await self.store.delete_access_token(access_token)
# see if any modules want to know about this
await self.password_auth_provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
+ user_id=token.user_id,
+ device_id=token.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
- if user_info.token_id is not None:
+ if token.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_info.user_id, (user_info.token_id,)
+ token.user_id, (token.token_id,)
)
async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 1a8379854..f5c586f65 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -747,7 +748,13 @@ def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+ device.update(
+ {
+ "last_seen_user_agent": ip.get("user_agent"),
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ }
+ )
class DeviceListUpdater:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a0fb4dc2b..a1e1cd1c3 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService
from synapse.module_api import NOT_SPAM
from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -135,7 +135,7 @@ class DirectoryHandler:
else:
# Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room).
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = await self.store.get_rooms_for_user(user_id)
@@ -199,7 +199,7 @@ class DirectoryHandler:
user_id = requester.user.to_string()
try:
- can_delete = await self._user_can_delete_alias(room_alias, user_id)
+ can_delete = await self._user_can_delete_alias(room_alias, requester)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@@ -402,7 +402,9 @@ class DirectoryHandler:
# either no interested services, or no service with an exclusive lock
return True
- async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
+ async def _user_can_delete_alias(
+ self, alias: RoomAlias, requester: Requester
+ ) -> bool:
"""Determine whether a user can delete an alias.
One of the following must be true:
@@ -415,7 +417,7 @@ class DirectoryHandler:
"""
creator = await self.store.get_room_alias_creator(alias.to_string())
- if creator == user_id:
+ if creator == requester.user.to_string():
return True
# Resolve the alias to the corresponding room.
@@ -424,9 +426,7 @@ class DirectoryHandler:
if not room_id:
return False
- return await self.auth.check_can_change_room_list(
- room_id, UserID.from_string(user_id)
- )
+ return await self.auth.check_can_change_room_list(room_id, requester)
async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
@@ -465,7 +465,7 @@ class DirectoryHandler:
raise SynapseError(400, "Unknown room")
can_change_room_list = await self.auth.check_can_change_room_list(
- room_id, requester.user
+ room_id, requester
)
if not can_change_room_list:
raise AuthError(
@@ -530,10 +530,8 @@ class DirectoryHandler:
Get a list of the aliases that currently point to this room on this server
"""
# allow access to server admins and current members of the room
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
if not is_admin:
- await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string()
- )
+ await self.auth.check_user_in_room_or_world_readable(room_id, requester)
return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index aa1ac376f..a3098669f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,6 +32,7 @@ from typing import (
)
import attr
+from prometheus_client import Histogram
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -59,7 +60,7 @@ from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import NOT_SPAM
from synapse.replication.http.federation import (
@@ -79,6 +80,29 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Added to debug performance and track progress on optimizations
+backfill_processing_before_timer = Histogram(
+ "synapse_federation_backfill_processing_before_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.1,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 15.0,
+ 20.0,
+ 30.0,
+ 40.0,
+ 60.0,
+ 80.0,
+ "+Inf",
+ ),
+)
+
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
@@ -138,6 +162,7 @@ class FederationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -197,12 +222,39 @@ class FederationHandler:
return. This is used as part of the heuristic to decide if we
should back paginate.
"""
+ # Starting the processing time here so we can include the room backfill
+ # linearizer lock queue in the timing
+ processing_start_time = self.clock.time_msec()
+
async with self._room_backfill.queue(room_id):
- return await self._maybe_backfill_inner(room_id, current_depth, limit)
+ return await self._maybe_backfill_inner(
+ room_id,
+ current_depth,
+ limit,
+ processing_start_time=processing_start_time,
+ )
async def _maybe_backfill_inner(
- self, room_id: str, current_depth: int, limit: int
+ self,
+ room_id: str,
+ current_depth: int,
+ limit: int,
+ *,
+ processing_start_time: int,
) -> bool:
+ """
+ Checks whether the `current_depth` is at or approaching any backfill
+ points in the room and if so, will backfill. We only care about
+ checking backfill points that happened before the `current_depth`
+ (meaning less than or equal to the `current_depth`).
+
+ Args:
+ room_id: The room to backfill in.
+ current_depth: The depth to check at for any upcoming backfill points.
+ limit: The max number of events to request from the remote federated server.
+ processing_start_time: The time when `maybe_backfill` started
+ processing. Only used for timing.
+ """
backwards_extremities = [
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
@@ -370,6 +422,14 @@ class FederationHandler:
logger.debug(
"_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "extremities_to_request",
+ str(extremities_to_request),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+ str(len(extremities_to_request)),
+ )
# Now we need to decide which hosts to hit first.
@@ -425,6 +485,11 @@ class FederationHandler:
return False
+ processing_end_time = self.clock.time_msec()
+ backfill_processing_before_timer.observe(
+ (processing_end_time - processing_start_time) / 1000
+ )
+
success = await try_backfill(likely_domains)
if success:
return True
@@ -1081,6 +1146,8 @@ class FederationHandler:
return event
+ @trace
+ @tag_args
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 8968b705d..048c4111f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -29,7 +29,7 @@ from typing import (
Tuple,
)
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
from synapse import event_auth
from synapse.api.constants import (
@@ -59,7 +59,13 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import nested_logging_context
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import (
+ SynapseTags,
+ set_tag,
+ start_active_span,
+ tag_args,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
@@ -92,6 +98,36 @@ soft_failed_event_counter = Counter(
"Events received over federation that we marked as soft_failed",
)
+# Added to debug performance and track progress on optimizations
+backfill_processing_after_timer = Histogram(
+ "synapse_federation_backfill_processing_after_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.1,
+ 0.25,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 7.5,
+ 10.0,
+ 15.0,
+ 20.0,
+ 25.0,
+ 30.0,
+ 40.0,
+ 50.0,
+ 60.0,
+ 80.0,
+ 100.0,
+ 120.0,
+ 150.0,
+ 180.0,
+ "+Inf",
+ ),
+)
+
class FederationEventHandler:
"""Handles events that originated from federation.
@@ -410,6 +446,7 @@ class FederationEventHandler:
prev_member_event,
)
+ @trace
async def process_remote_join(
self,
origin: str,
@@ -597,20 +634,21 @@ class FederationEventHandler:
if not events:
return
- # if there are any events in the wrong room, the remote server is buggy and
- # should not be trusted.
- for ev in events:
- if ev.room_id != room_id:
- raise InvalidResponseError(
- f"Remote server {dest} returned event {ev.event_id} which is in "
- f"room {ev.room_id}, when we were backfilling in {room_id}"
- )
+ with backfill_processing_after_timer.time():
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
+ )
- await self._process_pulled_events(
- dest,
- events,
- backfilled=True,
- )
+ await self._process_pulled_events(
+ dest,
+ events,
+ backfilled=True,
+ )
@trace
async def _get_missing_events_for_pdu(
@@ -715,7 +753,7 @@ class FederationEventHandler:
@trace
async def _process_pulled_events(
- self, origin: str, events: Iterable[EventBase], backfilled: bool
+ self, origin: str, events: Collection[EventBase], backfilled: bool
) -> None:
"""Process a batch of events we have pulled from a remote server
@@ -730,6 +768,15 @@ class FederationEventHandler:
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
"""
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str([event.event_id for event in events]),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(events)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
logger.debug(
"processing pulled backfilled=%s events=%s",
backfilled,
@@ -753,6 +800,7 @@ class FederationEventHandler:
await self._process_pulled_event(origin, ev, backfilled=backfilled)
@trace
+ @tag_args
async def _process_pulled_event(
self, origin: str, event: EventBase, backfilled: bool
) -> None:
@@ -854,6 +902,7 @@ class FederationEventHandler:
else:
raise
+ @trace
async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase
) -> EventContext:
@@ -970,6 +1019,8 @@ class FederationEventHandler:
event, state_ids_before_event=state_map, partial_state=partial_state
)
+ @trace
+ @tag_args
async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
@@ -1009,10 +1060,10 @@ class FederationEventHandler:
logger.debug("Fetching %i events from cache/store", len(desired_events))
have_events = await self._store.have_seen_events(room_id, desired_events)
- missing_desired_events = desired_events - have_events
+ missing_desired_event_ids = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
- len(missing_desired_events),
+ len(missing_desired_event_ids),
len(have_events),
)
@@ -1024,13 +1075,30 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
- missing_auth_events = set(auth_event_ids) - have_events
- missing_auth_events.difference_update(
- await self._store.have_seen_events(room_id, missing_auth_events)
+ missing_auth_event_ids = set(auth_event_ids) - have_events
+ missing_auth_event_ids.difference_update(
+ await self._store.have_seen_events(room_id, missing_auth_event_ids)
)
- logger.debug("We are also missing %i auth events", len(missing_auth_events))
+ logger.debug("We are also missing %i auth events", len(missing_auth_event_ids))
- missing_events = missing_desired_events | missing_auth_events
+ missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
+
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+ str(missing_auth_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+ str(len(missing_auth_event_ids)),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+ str(missing_desired_event_ids),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+ str(len(missing_desired_event_ids)),
+ )
# Making an individual request for each of 1000s of events has a lot of
# overhead. On the other hand, we don't really want to fetch all of the events
@@ -1041,13 +1109,13 @@ class FederationEventHandler:
#
# TODO: might it be better to have an API which lets us do an aggregate event
# request
- if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+ if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
logger.debug("Requesting complete state from remote")
await self._get_state_and_persist(destination, room_id, event_id)
else:
- logger.debug("Fetching %i events from remote", len(missing_events))
+ logger.debug("Fetching %i events from remote", len(missing_event_ids))
await self._get_events_and_persist(
- destination=destination, room_id=room_id, event_ids=missing_events
+ destination=destination, room_id=room_id, event_ids=missing_event_ids
)
# We now need to fill out the state map, which involves fetching the
@@ -1104,6 +1172,14 @@ class FederationEventHandler:
event_id,
failed_to_fetch,
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+ str(failed_to_fetch),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+ str(len(failed_to_fetch)),
+ )
if remote_event.is_state() and remote_event.rejected_reason is None:
state_map[
@@ -1112,6 +1188,8 @@ class FederationEventHandler:
return state_map
+ @trace
+ @tag_args
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
) -> None:
@@ -1133,6 +1211,7 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=(event_id,)
)
+ @trace
async def _process_received_pdu(
self,
origin: str,
@@ -1283,6 +1362,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
+ @trace
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@@ -1314,7 +1394,7 @@ class FederationEventHandler:
logger.debug("_handle_marker_event: received %s", marker_event)
insertion_event_id = marker_event.content.get(
- EventContentFields.MSC2716_MARKER_INSERTION
+ EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
)
if insertion_event_id is None:
@@ -1414,6 +1494,8 @@ class FederationEventHandler:
return event_from_response
+ @trace
+ @tag_args
async def _get_events_and_persist(
self, destination: str, room_id: str, event_ids: Collection[str]
) -> None:
@@ -1459,6 +1541,7 @@ class FederationEventHandler:
logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
await self._auth_and_persist_outliers(room_id, events)
+ @trace
async def _auth_and_persist_outliers(
self, room_id: str, events: Iterable[EventBase]
) -> None:
@@ -1477,6 +1560,16 @@ class FederationEventHandler:
"""
event_map = {event.event_id: event for event in events}
+ event_ids = event_map.keys()
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+
# filter out any events we have already seen. This might happen because
# the events were eagerly pushed to us (eg, during a room join), or because
# another thread has raced against us since we decided to request the event.
@@ -1593,6 +1686,7 @@ class FederationEventHandler:
backfilled=True,
)
+ @trace
async def _check_event_auth(
self, origin: Optional[str], event: EventBase, context: EventContext
) -> None:
@@ -1631,6 +1725,14 @@ class FederationEventHandler:
claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
origin, event
)
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+ str([ev.event_id for ev in claimed_auth_events]),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+ str(len(claimed_auth_events)),
+ )
# ... and check that the event passes auth at those auth events.
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
@@ -1728,6 +1830,7 @@ class FederationEventHandler:
)
context.rejected = RejectedReason.AUTH_ERROR
+ @trace
async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess:
return
@@ -1935,6 +2038,8 @@ class FederationEventHandler:
# instead we raise an AuthError, which will make the caller ignore it.
raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
+ @trace
+ @tag_args
async def _get_remote_auth_chain_for_event(
self, destination: str, room_id: str, event_id: str
) -> None:
@@ -1963,6 +2068,7 @@ class FederationEventHandler:
await self._auth_and_persist_outliers(room_id, remote_auth_events)
+ @trace
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None:
@@ -2071,8 +2177,17 @@ class FederationEventHandler:
self._message_handler.maybe_schedule_expiry(event)
if not backfilled: # Never notify for backfilled events
- for event in events:
- await self._notify_persisted_event(event, max_stream_token)
+ with start_active_span("notify_persisted_events"):
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "event_ids",
+ str([ev.event_id for ev in events]),
+ )
+ set_tag(
+ SynapseTags.RESULT_PREFIX + "event_ids.length",
+ str(len(events)),
+ )
+ for event in events:
+ await self._notify_persisted_event(event, max_stream_token)
return max_stream_token.stream
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 6484e47e5..860c82c11 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -309,18 +309,18 @@ class InitialSyncHandler:
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
- user_id = requester.user.to_string()
-
(
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
room_id,
- user_id,
+ requester,
allow_departed_users=True,
)
is_peeking = member_event_id is None
+ user_id = requester.user.to_string()
+
if membership == Membership.JOIN:
result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e6fc93dd2..08ee7caf2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -104,7 +104,7 @@ class MessageHandler:
async def get_room_data(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
event_type: str,
state_key: str,
@@ -112,7 +112,7 @@ class MessageHandler:
"""Get data from a room.
Args:
- user_id
+ requester: The user who did the request.
room_id
event_type
state_key
@@ -125,7 +125,7 @@ class MessageHandler:
membership,
membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership == Membership.JOIN:
@@ -161,11 +161,10 @@ class MessageHandler:
async def get_state_events(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None,
- is_guest: bool = False,
) -> List[dict]:
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@@ -174,14 +173,13 @@ class MessageHandler:
visible.
Args:
- user_id: The user requesting state events.
+ requester: The user requesting state events.
room_id: The room ID to get all state events from.
state_filter: The state filter used to fetch state from the database.
at_token: the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
- is_guest: whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
Raises:
@@ -191,6 +189,7 @@ class MessageHandler:
members of this room.
"""
state_filter = state_filter or StateFilter.all()
+ user_id = requester.user.to_string()
if at_token:
last_event_id = (
@@ -223,7 +222,7 @@ class MessageHandler:
membership,
membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership == Membership.JOIN:
@@ -317,12 +316,11 @@ class MessageHandler:
Returns:
A dict of user_id to profile info
"""
- user_id = requester.user.to_string()
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
membership, _ = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if membership != Membership.JOIN:
raise SynapseError(
@@ -331,12 +329,19 @@ class MessageHandler:
msg="Getting joined members while not being a current member of the room is forbidden.",
)
- users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
+ users_with_profile = (
+ await self._state_storage_controller.get_users_in_room_with_profiles(
+ room_id
+ )
+ )
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
# is a user in the room that the AS is "interested in"
- if requester.app_service and user_id not in users_with_profile:
+ if (
+ requester.app_service
+ and requester.user.to_string() not in users_with_profile
+ ):
for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid):
break
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index e1e34e3b1..74e944bce 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -464,7 +464,7 @@ class PaginationHandler:
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
if pagin_config.direction == "b":
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e7b3b5be5..76ea8129c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -29,7 +29,13 @@ from synapse.api.constants import (
JoinRules,
LoginType,
)
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ ConsentNotGivenError,
+ InvalidClientTokenError,
+ SynapseError,
+)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
@@ -185,10 +191,7 @@ class RegistrationHandler:
)
if guest_access_token:
user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if (
- not user_data.is_guest
- or UserID.from_string(user_data.user_id).localpart != localpart
- ):
+ if not user_data.is_guest or user_data.user.localpart != localpart:
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@@ -625,7 +628,7 @@ class RegistrationHandler:
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
if not service:
- raise AuthError(403, "Invalid application service token.")
+ raise InvalidClientTokenError()
if not service.is_interested_in_user(user_id):
raise SynapseError(
400,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 72d25df8c..28d7093f0 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -103,7 +103,7 @@ class RelationsHandler:
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7565dca66..16c120301 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -721,7 +721,7 @@ class RoomCreationHandler:
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
# Let the third party rules modify the room creation config if needed, or abort
# the room creation entirely with an exception.
@@ -1291,7 +1291,7 @@ class RoomContextHandler:
"""
user = requester.user
if use_admin_priviledge:
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 70dc69c80..65b9a655d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -179,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""Try and join a room that this server is not in
Args:
- requester
+ requester: The user making the request, according to the access token.
remote_room_hosts: List of servers that can be used to join via.
room_id: Room that we are trying to join
user: User who is trying to join
@@ -689,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
errcode=Codes.BAD_JSON,
)
- if "avatar_url" in content:
+ if "avatar_url" in content and content.get("avatar_url") is not None:
if not await self.profile_handler.check_avatar_size_and_mime_type(
content["avatar_url"],
):
@@ -744,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_requester_admin = True
else:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin:
if self.config.server.block_non_admin_invites:
@@ -868,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
bypass_spam_checker = True
else:
- bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+ bypass_spam_checker = await self.auth.is_server_admin(requester)
inviter = await self._get_inviter(target.to_string(), room_id)
if (
@@ -1410,7 +1410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ShadowBanError if the requester has been shadow-banned.
"""
if self.config.server.block_non_admin_invites:
- is_requester_admin = await self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -1693,7 +1693,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
check_complexity
and self.hs.config.server.limit_remote_rooms.admins_can_join
):
- check_complexity = not await self.auth.is_server_admin(user)
+ check_complexity = not await self.store.is_server_admin(user)
if check_complexity:
# Fetch the room complexity
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f28fc0238..15c00b9f2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,19 @@
# limitations under the License.
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ FrozenSet,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -89,7 +101,7 @@ class SyncConfig:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
prev_batch: StreamToken
- events: List[EventBase]
+ events: Sequence[EventBase]
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
@@ -507,10 +519,17 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
+ # FIXME(faster_joins): We use the partial state here as
+ # we don't want to block `/sync` on finishing a lazy join.
+ # Which should be fine once
+ # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+ # since we shouldn't reach here anymore?
+ # Note that we use the current state as a whitelist for filtering
+ # `recents`, so partial state is only a problem when a membership
+ # event turns up in `recents` but has not made it into the current
+ # state.
current_state_ids_map = (
- await self._state_storage_controller.get_current_state_ids(
- room_id
- )
+ await self.store.get_partial_current_state_ids(room_id)
)
current_state_ids = frozenset(current_state_ids_map.values())
@@ -579,7 +598,13 @@ class SyncHandler:
if any(e.is_state() for e in loaded_recents):
# FIXME(faster_joins): We use the partial state here as
# we don't want to block `/sync` on finishing a lazy join.
- # Is this the correct way of doing it?
+ # Which should be fine once
+ # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+ # since we shouldn't reach here anymore?
+ # Note that we use the current state as a whitelist for filtering
+ # `loaded_recents`, so partial state is only a problem when a
+ # membership event turns up in `loaded_recents` but has not made it
+ # into the current state.
current_state_ids_map = (
await self.store.get_partial_current_state_ids(room_id)
)
@@ -627,7 +652,10 @@ class SyncHandler:
)
async def get_state_after_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
@@ -635,9 +663,14 @@ class SyncHandler:
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
"""
state_ids = await self._state_storage_controller.get_state_ids_for_event(
- event_id, state_filter=state_filter or StateFilter.all()
+ event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -660,6 +693,7 @@ class SyncHandler:
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
@@ -667,6 +701,9 @@ class SyncHandler:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the last event in the room before `stream_position` and
+ `state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
@@ -678,7 +715,9 @@ class SyncHandler:
if last_event_id:
state = await self.get_state_after_event(
- last_event_id, state_filter=state_filter or StateFilter.all()
+ last_event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
else:
@@ -852,16 +891,26 @@ class SyncHandler:
now_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
- """Works out the difference in state between the start of the timeline
- and the previous sync.
+ """Works out the difference in state between the end of the previous sync and
+ the start of the timeline.
Args:
room_id:
batch: The timeline batch for the room that will be sent to the user.
sync_config:
- since_token: Token of the end of the previous batch. May be None.
+ since_token: Token of the end of the previous batch. May be `None`.
now_token: Token of the end of the current batch.
full_state: Whether to force returning the full state.
+ `lazy_load_members` still applies when `full_state` is `True`.
+
+ Returns:
+ The state to return in the sync response for the room.
+
+ Clients will overlay this onto the state at the end of the previous sync to
+ arrive at the state at the start of the timeline.
+
+ Clients will then overlay state events in the timeline to arrive at the
+ state at the end of the timeline, in preparation for the next sync.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -869,8 +918,17 @@ class SyncHandler:
# TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"):
+ # The memberships needed for events in the timeline.
+ # Only calculated when `lazy_load_members` is on.
+ members_to_fetch: Optional[Set[str]] = None
- members_to_fetch = None
+ # A dictionary mapping user IDs to the first event in the timeline sent by
+ # them. Only calculated when `lazy_load_members` is on.
+ first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
+
+ # The contribution to the room state from state events in the timeline.
+ # Only contains the last event for any given state key.
+ timeline_state: StateMap[str]
lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = (
@@ -881,10 +939,23 @@ class SyncHandler:
# We only request state for the members needed to display the
# timeline:
- members_to_fetch = {
- event.sender # FIXME: we also care about invite targets etc.
- for event in batch.events
- }
+ timeline_state = {}
+
+ members_to_fetch = set()
+ first_event_by_sender_map = {}
+ for event in batch.events:
+ # Build the map from user IDs to the first timeline event they sent.
+ if event.sender not in first_event_by_sender_map:
+ first_event_by_sender_map[event.sender] = event
+
+ # We need the event's sender, unless their membership was in a
+ # previous timeline event.
+ if (EventTypes.Member, event.sender) not in timeline_state:
+ members_to_fetch.add(event.sender)
+ # FIXME: we also care about invite targets etc.
+
+ if event.is_state():
+ timeline_state[(event.type, event.state_key)] = event.event_id
if full_state:
# always make sure we LL ourselves so we know we're in the room
@@ -894,55 +965,80 @@ class SyncHandler:
members_to_fetch.add(sync_config.user.to_string())
state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
- else:
- state_filter = StateFilter.all()
- timeline_state = {
- (event.type, event.state_key): event.event_id
- for event in batch.events
- if event.is_state()
- }
+ # We are happy to use partial state to compute the `/sync` response.
+ # Since partial state may not include the lazy-loaded memberships we
+ # require, we fix up the state response afterwards with memberships from
+ # auth events.
+ await_full_state = False
+ else:
+ timeline_state = {
+ (event.type, event.state_key): event.event_id
+ for event in batch.events
+ if event.is_state()
+ }
+
+ state_filter = StateFilter.all()
+ await_full_state = True
+
+ # Now calculate the state to return in the sync response for the room.
+ # This is more or less the change in state between the end of the previous
+ # sync's timeline and the start of the current sync's timeline.
+ # See the docstring above for details.
+ state_ids: StateMap[str]
if full_state:
if batch:
- current_state_ids = (
+ state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ batch.events[-1].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
- state_ids = (
+ state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ batch.events[0].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
- current_state_ids = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ state_at_timeline_end = await self.get_state_at(
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
- state_ids = current_state_ids
+ state_at_timeline_start = state_at_timeline_end
state_ids = _calculate_state(
timeline_contains=timeline_state,
- timeline_start=state_ids,
- previous={},
- current=current_state_ids,
+ timeline_start=state_at_timeline_start,
+ timeline_end=state_at_timeline_end,
+ previous_timeline_end={},
lazy_load_members=lazy_load_members,
)
elif batch.limited:
if batch:
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ batch.events[0].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
# for now, we disable LL for gappy syncs - see
@@ -964,28 +1060,35 @@ class SyncHandler:
# is indeed the case.
assert since_token is not None
state_at_previous_sync = await self.get_state_at(
- room_id, stream_position=since_token, state_filter=state_filter
+ room_id,
+ stream_position=since_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
if batch:
- current_state_ids = (
+ state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter
+ batch.events[-1].event_id,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
)
else:
- # Its not clear how we get here, but empirically we do
- # (#5407). Logging has been added elsewhere to try and
- # figure out where this state comes from.
- current_state_ids = await self.get_state_at(
- room_id, stream_position=now_token, state_filter=state_filter
+ # We can get here if the user has ignored the senders of all
+ # the recent events.
+ state_at_timeline_end = await self.get_state_at(
+ room_id,
+ stream_position=now_token,
+ state_filter=state_filter,
+ await_full_state=await_full_state,
)
state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
- previous=state_at_previous_sync,
- current=current_state_ids,
+ timeline_end=state_at_timeline_end,
+ previous_timeline_end=state_at_previous_sync,
# we have to include LL members in case LL initial sync missed them
lazy_load_members=lazy_load_members,
)
@@ -1008,8 +1111,30 @@ class SyncHandler:
(EventTypes.Member, member)
for member in members_to_fetch
),
+ await_full_state=False,
)
+ # If we only have partial state for the room, `state_ids` may be missing the
+ # memberships we wanted. We attempt to find some by digging through the auth
+ # events of timeline events.
+ if lazy_load_members and await self.store.is_partial_state_room(room_id):
+ assert members_to_fetch is not None
+ assert first_event_by_sender_map is not None
+
+ additional_state_ids = (
+ await self._find_missing_partial_state_memberships(
+ room_id, members_to_fetch, first_event_by_sender_map, state_ids
+ )
+ )
+ state_ids = {**state_ids, **additional_state_ids}
+
+ # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+ # the memberships of all event senders in the timeline. This is because we
+ # may not have sent the memberships in a previous sync.
+
+ # When `include_redundant_members` is on, we send all the lazy-loaded
+ # memberships of event senders. Otherwise we make an effort to limit the set
+ # of memberships we send to those that we have not already sent to this client.
if lazy_load_members and not include_redundant_members:
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1050,6 +1175,99 @@ class SyncHandler:
)
}
+ async def _find_missing_partial_state_memberships(
+ self,
+ room_id: str,
+ members_to_fetch: Collection[str],
+ events_with_membership_auth: Mapping[str, EventBase],
+ found_state_ids: StateMap[str],
+ ) -> StateMap[str]:
+ """Finds missing memberships from a set of auth events and returns them as a
+ state map.
+
+ Args:
+ room_id: The partial state room to find the remaining memberships for.
+ members_to_fetch: The memberships to find.
+ events_with_membership_auth: A mapping from user IDs to events whose auth
+ events are known to contain their membership.
+ found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+ memberships that have been previously found. Entries in
+ `members_to_fetch` that have a membership in `found_state_ids` are
+ ignored.
+
+ Returns:
+ A dict from ("m.room.member", state_key) -> state_event_id, containing the
+ memberships missing from `found_state_ids`.
+
+ Raises:
+ KeyError: if `events_with_membership_auth` does not have an entry for a
+ missing membership. Memberships in `found_state_ids` do not need an
+ entry in `events_with_membership_auth`.
+ """
+ additional_state_ids: MutableStateMap[str] = {}
+
+ # Tracks the missing members for logging purposes.
+ missing_members = set()
+
+ # Identify memberships missing from `found_state_ids` and pick out the auth
+ # events in which to look for them.
+ auth_event_ids: Set[str] = set()
+ for member in members_to_fetch:
+ if (EventTypes.Member, member) in found_state_ids:
+ continue
+
+ missing_members.add(member)
+ event_with_membership_auth = events_with_membership_auth[member]
+ auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+ auth_events = await self.store.get_events(auth_event_ids)
+
+ # Run through the missing memberships once more, picking out the memberships
+ # from the pile of auth events we have just fetched.
+ for member in members_to_fetch:
+ if (EventTypes.Member, member) in found_state_ids:
+ continue
+
+ event_with_membership_auth = events_with_membership_auth[member]
+
+ # Dig through the auth events to find the desired membership.
+ for auth_event_id in event_with_membership_auth.auth_event_ids():
+ # We only store events once we have all their auth events,
+ # so the auth event must be in the pile we have just
+ # fetched.
+ auth_event = auth_events[auth_event_id]
+
+ if (
+ auth_event.type == EventTypes.Member
+ and auth_event.state_key == member
+ ):
+ missing_members.remove(member)
+ additional_state_ids[
+ (EventTypes.Member, member)
+ ] = auth_event.event_id
+ break
+
+ if missing_members:
+ # There really shouldn't be any missing memberships now. Either:
+ # * we couldn't find an auth event, which shouldn't happen because we do
+ # not persist events with persisting their auth events first, or
+ # * the set of auth events did not contain a membership we wanted, which
+ # means our caller didn't compute the events in `members_to_fetch`
+ # correctly, or we somehow accepted an event whose auth events were
+ # dodgy.
+ logger.error(
+ "Failed to find memberships for %s in partial state room "
+ "%s in the auth events of %s.",
+ missing_members,
+ room_id,
+ [
+ events_with_membership_auth[member].event_id
+ for member in missing_members
+ ],
+ )
+
+ return additional_state_ids
+
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> NotifCounts:
@@ -1694,7 +1912,11 @@ class SyncHandler:
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
- old_state_ids = await self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+ )
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
@@ -1720,7 +1942,13 @@ class SyncHandler:
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
- old_state_ids = await self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Member, user_id)]
+ ),
+ )
old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None
)
@@ -2215,8 +2443,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
def _calculate_state(
timeline_contains: StateMap[str],
timeline_start: StateMap[str],
- previous: StateMap[str],
- current: StateMap[str],
+ timeline_end: StateMap[str],
+ previous_timeline_end: StateMap[str],
lazy_load_members: bool,
) -> StateMap[str]:
"""Works out what state to include in a sync response.
@@ -2224,45 +2452,50 @@ def _calculate_state(
Args:
timeline_contains: state in the timeline
timeline_start: state at the start of the timeline
- previous: state at the end of the previous sync (or empty dict
+ timeline_end: state at the end of the timeline
+ previous_timeline_end: state at the end of the previous sync (or empty dict
if this is an initial sync)
- current: state at the end of the timeline
lazy_load_members: whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
"""
- event_id_to_key = {
- e: key
- for key, e in itertools.chain(
+ event_id_to_state_key = {
+ event_id: state_key
+ for state_key, event_id in itertools.chain(
timeline_contains.items(),
- previous.items(),
timeline_start.items(),
- current.items(),
+ timeline_end.items(),
+ previous_timeline_end.items(),
)
}
- c_ids = set(current.values())
- ts_ids = set(timeline_start.values())
- p_ids = set(previous.values())
- tc_ids = set(timeline_contains.values())
+ timeline_end_ids = set(timeline_end.values())
+ timeline_start_ids = set(timeline_start.values())
+ previous_timeline_end_ids = set(previous_timeline_end.values())
+ timeline_contains_ids = set(timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
# as we may not have sent them to the client before. We find these membership
# events by filtering them out of timeline_start, which has already been filtered
# to only include membership events for the senders in the timeline.
- # In practice, we can do this by removing them from the p_ids list,
- # which is the list of relevant state we know we have already sent to the client.
+ # In practice, we can do this by removing them from the previous_timeline_end_ids
+ # list, which is the list of relevant state we know we have already sent to the
+ # client.
# see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
- p_ids.difference_update(
+ previous_timeline_end_ids.difference_update(
e for t, e in timeline_start.items() if t[0] == EventTypes.Member
)
- state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+ state_ids = (
+ (timeline_end_ids | timeline_start_ids)
+ - previous_timeline_end_ids
+ - timeline_contains_ids
+ )
- return {event_id_to_key[e]: e for e in state_ids}
+ return {event_id_to_state_key[e]: e for e in state_ids}
@attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 27aa0d312..bcac3372a 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
) -> None:
target_user_id = target_user.to_string()
- auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
- if target_user_id != auth_user_id:
+ if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned:
@@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
- await self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has started typing in %s", target_user_id, room_id)
@@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str
) -> None:
target_user_id = target_user.to_string()
- auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver")
- if target_user_id != auth_user_id:
+ if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned:
@@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
- await self.auth.check_user_in_room(room_id, target_user_id)
+ await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has stopped typing in %s", target_user_id, room_id)
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0..26aaabfb3 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,9 +23,12 @@ from typing import (
Optional,
Sequence,
Tuple,
+ Type,
+ TypeVar,
overload,
)
+from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from twisted.web.server import Request
@@ -694,6 +697,28 @@ def parse_json_object_from_request(
return content
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+ request: Request, model_type: Type[Model]
+) -> Model:
+ """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+ validate using the given pydantic model.
+
+ Raises:
+ SynapseError if the request body couldn't be decoded as JSON or
+ if it wasn't a JSON object.
+ """
+ content = parse_json_object_from_request(request, allow_empty_body=False)
+ try:
+ instance = model_type.parse_obj(content)
+ except ValidationError as e:
+ raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
+
+ return instance
+
+
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent = []
for k in required:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78..1155f3f61 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -226,7 +226,7 @@ class SynapseRequest(Request):
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we return both.
- if self._requester.user.to_string() != authenticated_entity:
+ if requester != authenticated_entity:
return requester, authenticated_entity
return requester, None
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fa3f76c27..482316a1f 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import (
Any,
Callable,
Collection,
+ ContextManager,
Dict,
Generator,
Iterable,
@@ -309,6 +310,19 @@ class SynapseTags:
# The name of the external cache
CACHE_NAME = "cache.name"
+ # Used to tag function arguments
+ #
+ # Tag a named arg. The name of the argument should be appended to this prefix.
+ FUNC_ARG_PREFIX = "ARG."
+ # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`)
+ FUNC_ARGS = "args"
+ # Tag keyword args
+ FUNC_KWARGS = "kwargs"
+
+ # Some intermediate result that's interesting to the function. The label for
+ # the result should be appended to this prefix.
+ RESULT_PREFIX = "RESULT."
+
class SynapseBaggage:
FORCE_TRACING = "synapse-force-tracing"
@@ -823,75 +837,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def _custom_sync_async_decorator(
+ func: Callable[P, R],
+ wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
+ """
+ Decorates a function that is sync or async (coroutines), or that returns a Twisted
+ `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+ Example usage:
+ ```py
+ # Decorator to time the function and log it out
+ def duration(func: Callable[P, R]) -> Callable[P, R]:
+ @contextlib.contextmanager
+ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+ start_ts = time.time()
+ try:
+ yield
+ finally:
+ end_ts = time.time()
+ duration = end_ts - start_ts
+ logger.info("%s took %s seconds", func.__name__, duration)
+ return _custom_sync_async_decorator(func, _wrapping_logic)
+ ```
+
+ Args:
+ func: The function to be decorated
+ wrapping_logic: The business logic of your custom decorator.
+ This should be a ContextManager so you are able to run your logic
+ before/after the function as desired.
+ """
+
+ if inspect.iscoroutinefunction(func):
+
+ @wraps(func)
+ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ with wrapping_logic(func, *args, **kwargs):
+ return await func(*args, **kwargs) # type: ignore[misc]
+
+ else:
+ # The other case here handles both sync functions and those
+ # decorated with inlineDeferred.
+ @wraps(func)
+ def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ scope = wrapping_logic(func, *args, **kwargs)
+ scope.__enter__()
+
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
+
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _wrapper # type: ignore[return-value]
+
+
+def trace_with_opname(
+ opname: str,
+ *,
+ tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to trace a function with a custom opname.
-
See the module's doc string for usage examples.
-
"""
- def decorator(func: Callable[P, R]) -> Callable[P, R]:
- if opentracing is None:
- return func # type: ignore[unreachable]
+ # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+ @contextlib.contextmanager # type: ignore[arg-type]
+ def _wrapping_logic(
+ func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+ ) -> Generator[None, None, None]:
+ with start_active_span(opname, tracer=tracer):
+ yield
- if inspect.iscoroutinefunction(func):
+ def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+ if not opentracing:
+ return func
- @wraps(func)
- async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- with start_active_span(opname):
- return await func(*args, **kwargs) # type: ignore[misc]
+ return _custom_sync_async_decorator(func, _wrapping_logic)
- else:
- # The other case here handles both sync functions and those
- # decorated with inlineDeferred.
- @wraps(func)
- def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- scope = start_active_span(opname)
- scope.__enter__()
-
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
-
- def call_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- def err_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- result.addCallbacks(call_back, err_back)
-
- else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
-
- scope.__exit__(None, None, None)
-
- return result
-
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
-
- return _trace_inner # type: ignore[return-value]
-
- return decorator
+ return _decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
-
Sets the operation name to that of the function's name.
-
See the module's doc string for usage examples.
"""
@@ -900,7 +956,7 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
- Tags all of the args to the active span.
+ Decorator to tag all of the args to the active span.
Args:
func: `func` is assumed to be a method taking a `self` parameter, or a
@@ -911,22 +967,25 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
if not opentracing:
return func
- @wraps(func)
- def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+ @contextlib.contextmanager # type: ignore[arg-type]
+ def _wrapping_logic(
+ func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+ ) -> Generator[None, None, None]:
argspec = inspect.getfullargspec(func)
# We use `[1:]` to skip the `self` object reference and `start=1` to
# make the index line up with `argspec.args`.
#
- # FIXME: We could update this handle any type of function by ignoring the
+ # FIXME: We could update this to handle any type of function by ignoring the
# first argument only if it's named `self` or `cls`. This isn't fool-proof
# but handles the idiomatic cases.
for i, arg in enumerate(args[1:], start=1): # type: ignore[index]
- set_tag("ARG_" + argspec.args[i], str(arg))
- set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index]
- set_tag("kwargs", str(kwargs))
- return func(*args, **kwargs)
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+ set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index]
+ set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
+ yield
- return _tag_args_inner
+ return _custom_sync_async_decorator(func, _wrapping_logic)
@contextlib.contextmanager
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6c0cc5a6c..440205e80 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -14,128 +14,235 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-from typing import Any, Dict, List
+"""
+Push rules is the system used to determine which events trigger a push (and a
+bump in notification counts).
-from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+This consists of a list of "push rules" for each user, where a push rule is a
+pair of "conditions" and "actions". When a user receives an event Synapse
+iterates over the list of push rules until it finds one where all the conditions
+match the event, at which point "actions" describe the outcome (e.g. notify,
+highlight, etc).
+
+Push rules are split up into 5 different "kinds" (aka "priority classes"), which
+are run in order:
+ 1. Override — highest priority rules, e.g. always ignore notices
+ 2. Content — content specific rules, e.g. @ notifications
+ 3. Room — per room rules, e.g. enable/disable notifications for all messages
+ in a room
+ 4. Sender — per sender rules, e.g. never notify for messages from a given
+ user
+ 5. Underride — the lowest priority "default" rules, e.g. notify for every
+ message.
+
+The set of "base rules" are the list of rules that every user has by default. A
+user can modify their copy of the push rules in one of three ways:
+
+ 1. Adding a new push rule of a certain kind
+ 2. Changing the actions of a base rule
+ 3. Enabling/disabling a base rule.
+
+The base rules are split into whether they come before or after a particular
+kind, so the order of push rule evaluation would be: base rules for before
+"override" kind, user defined "override" rules, base rules after "override"
+kind, etc, etc.
+"""
+
+import itertools
+import logging
+from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union
+
+import attr
+
+from synapse.config.experimental import ExperimentalConfig
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+
+logger = logging.getLogger(__name__)
-def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Combine the list of rules set by the user with the default push rules
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class PushRule:
+ """A push rule
- Args:
- rawrules: The rules the user has modified or set.
-
- Returns:
- A new list with the rules set by the user combined with the defaults.
+ Attributes:
+ rule_id: a unique ID for this rule
+ priority_class: what "kind" of push rule this is (see
+ `PRIORITY_CLASS_MAP` for mapping between int and kind)
+ conditions: the sequence of conditions that all need to match
+ actions: the actions to apply if all conditions are met
+ default: is this a base rule?
+ default_enabled: is this enabled by default?
"""
- ruleslist = []
- # Grab the base rules that the user has modified.
- # The modified base rules have a priority_class of -1.
- modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
+ rule_id: str
+ priority_class: int
+ conditions: Sequence[Mapping[str, str]]
+ actions: Sequence[Union[str, Mapping]]
+ default: bool = False
+ default_enabled: bool = True
- # Remove the modified base rules from the list, They'll be added back
- # in the default positions in the list.
- rawrules = [r for r in rawrules if r["priority_class"] >= 0]
- # shove the server default rules for each kind onto the end of each
- current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class PushRules:
+ """A collection of push rules for an account.
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ Can be iterated over, producing push rules in priority order.
+ """
+
+ # A mapping from rule ID to push rule that overrides a base rule. These will
+ # be returned instead of the base rule.
+ overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict)
+
+ # The following stores the custom push rules at each priority class.
+ #
+ # We keep these separate (rather than combining into one big list) to avoid
+ # copying the base rules around all the time.
+ override: List[PushRule] = attr.Factory(list)
+ content: List[PushRule] = attr.Factory(list)
+ room: List[PushRule] = attr.Factory(list)
+ sender: List[PushRule] = attr.Factory(list)
+ underride: List[PushRule] = attr.Factory(list)
+
+ def __iter__(self) -> Iterator[PushRule]:
+ # When iterating over the push rules we need to return the base rules
+ # interspersed at the correct spots.
+ for rule in itertools.chain(
+ BASE_PREPEND_OVERRIDE_RULES,
+ self.override,
+ BASE_APPEND_OVERRIDE_RULES,
+ self.content,
+ BASE_APPEND_CONTENT_RULES,
+ self.room,
+ self.sender,
+ self.underride,
+ BASE_APPEND_UNDERRIDE_RULES,
+ ):
+ # Check if a base rule has been overriden by a custom rule. If so
+ # return that instead.
+ override_rule = self.overriden_base_rules.get(rule.rule_id)
+ if override_rule:
+ yield override_rule
+ else:
+ yield rule
+
+ def __len__(self) -> int:
+ # The length is mostly used by caches to get a sense of "size" / amount
+ # of memory this object is using, so we only count the number of custom
+ # rules.
+ return (
+ len(self.overriden_base_rules)
+ + len(self.override)
+ + len(self.content)
+ + len(self.room)
+ + len(self.sender)
+ + len(self.underride)
)
- )
- for r in rawrules:
- if r["priority_class"] < current_prio_class:
- while r["priority_class"] < current_prio_class:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
- ruleslist.append(r)
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class FilteredPushRules:
+ """A wrapper around `PushRules` that filters out disabled experimental push
+ rules, and includes the "enabled" state for each rule when iterated over.
+ """
- while current_prio_class > 0:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ push_rules: PushRules
+ enabled_map: Dict[str, bool]
+ experimental_config: ExperimentalConfig
+
+ def __iter__(self) -> Iterator[Tuple[PushRule, bool]]:
+ for rule in self.push_rules:
+ if not _is_experimental_rule_enabled(
+ rule.rule_id, self.experimental_config
+ ):
+ continue
+
+ enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled)
+
+ yield rule, enabled
+
+ def __len__(self) -> int:
+ return len(self.push_rules)
+
+
+DEFAULT_EMPTY_PUSH_RULES = PushRules()
+
+
+def compile_push_rules(rawrules: List[PushRule]) -> PushRules:
+ """Given a set of custom push rules return a `PushRules` instance (which
+ includes the base rules).
+ """
+
+ if not rawrules:
+ # Fast path to avoid allocating empty lists when there are no custom
+ # rules for the user.
+ return DEFAULT_EMPTY_PUSH_RULES
+
+ rules = PushRules()
+
+ for rule in rawrules:
+ # We need to decide which bucket each custom push rule goes into.
+
+ # If it has the same ID as a base rule then it overrides that...
+ overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id)
+ if overriden_base_rule:
+ rules.overriden_base_rules[rule.rule_id] = attr.evolve(
+ overriden_base_rule, actions=rule.actions
)
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
+ continue
+
+ # ... otherwise it gets added to the appropriate priority class bucket
+ collection: List[PushRule]
+ if rule.priority_class == 5:
+ collection = rules.override
+ elif rule.priority_class == 4:
+ collection = rules.content
+ elif rule.priority_class == 3:
+ collection = rules.room
+ elif rule.priority_class == 2:
+ collection = rules.sender
+ elif rule.priority_class == 1:
+ collection = rules.underride
+ elif rule.priority_class <= 0:
+ logger.info(
+ "Got rule with priority class less than zero, but doesn't override a base rule: %s",
+ rule,
)
+ continue
+ else:
+ # We log and continue here so as not to break event sending
+ logger.error("Unknown priority class: %", rule.priority_class)
+ continue
- return ruleslist
-
-
-def make_base_append_rules(
- kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
-) -> List[Dict[str, Any]]:
- rules = []
-
- if kind == "override":
- rules = BASE_APPEND_OVERRIDE_RULES
- elif kind == "underride":
- rules = BASE_APPEND_UNDERRIDE_RULES
- elif kind == "content":
- rules = BASE_APPEND_CONTENT_RULES
-
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
+ collection.append(rule)
return rules
-def make_base_prepend_rules(
- kind: str,
- modified_base_rules: Dict[str, Dict[str, Any]],
-) -> List[Dict[str, Any]]:
- rules = []
-
- if kind == "override":
- rules = BASE_PREPEND_OVERRIDE_RULES
-
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
-
- return rules
+def _is_experimental_rule_enabled(
+ rule_id: str, experimental_config: ExperimentalConfig
+) -> bool:
+ """Used by `FilteredPushRules` to filter out experimental rules when they
+ have not been enabled.
+ """
+ if (
+ rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
+ and not experimental_config.msc3786_enabled
+ ):
+ return False
+ if (
+ rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+ and not experimental_config.msc3772_enabled
+ ):
+ return False
+ return True
-# We have to annotate these types, otherwise mypy infers them as
-# `List[Dict[str, Sequence[Collection[str]]]]`.
-BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/content/.m.rule.contains_user_name",
- "conditions": [
+BASE_APPEND_CONTENT_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["content"],
+ rule_id="global/content/.m.rule.contains_user_name",
+ conditions=[
{
"kind": "event_match",
"key": "content.body",
@@ -143,29 +250,33 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
"pattern_type": "user_localpart",
}
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
- }
+ )
]
-BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.master",
- "enabled": False,
- "conditions": [],
- "actions": ["dont_notify"],
- }
+BASE_PREPEND_OVERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.master",
+ default_enabled=False,
+ conditions=[],
+ actions=["dont_notify"],
+ )
]
-BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.suppress_notices",
- "conditions": [
+BASE_APPEND_OVERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.suppress_notices",
+ conditions=[
{
"kind": "event_match",
"key": "content.msgtype",
@@ -173,13 +284,15 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_suppress_notices",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
- {
- "rule_id": "global/override/.m.rule.invite_for_me",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.invite_for_me",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -195,21 +308,23 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
# Match the requester's MXID.
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
# is that we don't though. We add this override rule so that even if
# the room rule is set to notify, we don't get notifications about
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
- {
- "rule_id": "global/override/.m.rule.member_event",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.member_event",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -217,24 +332,28 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_member",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
- {
- "rule_id": "global/override/.m.rule.contains_display_name",
- "conditions": [{"kind": "contains_display_name"}],
- "actions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.contains_display_name",
+ conditions=[{"kind": "contains_display_name"}],
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
- },
- {
- "rule_id": "global/override/.m.rule.roomnotif",
- "conditions": [
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.roomnotif",
+ conditions=[
{
"kind": "event_match",
"key": "content.body",
@@ -247,11 +366,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_roomnotif_pl",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.tombstone",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": True}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.tombstone",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -265,11 +386,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_tombstone_statekey",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.reaction",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": True}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.m.rule.reaction",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -277,14 +400,16 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_reaction",
}
],
- "actions": ["dont_notify"],
- },
+ actions=["dont_notify"],
+ ),
# XXX: This is an experimental rule that is only enabled if msc3786_enabled
# is enabled, if it is not the rule gets filtered out in _load_rules() in
# PushRulesWorkerStore
- {
- "rule_id": "global/override/.org.matrix.msc3786.rule.room.server_acl",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["override"],
+ rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -298,15 +423,17 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_room_server_acl_state_key",
},
],
- "actions": [],
- },
+ actions=[],
+ ),
]
-BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/underride/.m.rule.call",
- "conditions": [
+BASE_APPEND_UNDERRIDE_RULES = [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.call",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -314,17 +441,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_call",
}
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "ring"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
- {
- "rule_id": "global/underride/.m.rule.room_one_to_one",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.room_one_to_one",
+ conditions=[
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
@@ -333,17 +462,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_message",
},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
+ ),
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.encrypted_room_one_to_one",
+ conditions=[
{"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
{
"kind": "event_match",
@@ -352,15 +483,17 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_encrypted",
},
],
- "actions": [
+ actions=[
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight", "value": False},
],
- },
- {
- "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
- "conditions": [
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.org.matrix.msc3772.thread_reply",
+ conditions=[
{
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.thread",
@@ -368,11 +501,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"sender_type": "user_id",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.m.rule.message",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.message",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -380,13 +515,15 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_message",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted",
- "conditions": [
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.m.rule.encrypted",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -394,11 +531,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_encrypted",
}
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.im.vector.jitsi",
- "conditions": [
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
+ PushRule(
+ default=True,
+ priority_class=PRIORITY_CLASS_MAP["underride"],
+ rule_id="global/underride/.im.vector.jitsi",
+ conditions=[
{
"kind": "event_match",
"key": "type",
@@ -418,29 +557,27 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
"_cache_key": "_is_state_event",
},
],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
+ actions=["notify", {"set_tweak": "highlight", "value": False}],
+ ),
]
BASE_RULE_IDS = set()
+BASE_RULES_BY_ID: Dict[str, PushRule] = {}
+
for r in BASE_APPEND_CONTENT_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["content"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_PREPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_APPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
for r in BASE_APPEND_UNDERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
+ BASE_RULE_IDS.add(r.rule_id)
+ BASE_RULES_BY_ID[r.rule_id] = r
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 713dcf695..ccd512be5 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,7 +15,18 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
@@ -30,6 +41,7 @@ from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
+from .baserules import FilteredPushRules, PushRule
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
@@ -112,7 +124,7 @@ class BulkPushRuleEvaluator:
async def _get_rules_for_event(
self,
event: EventBase,
- ) -> Dict[str, List[Dict[str, Any]]]:
+ ) -> Dict[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
@@ -186,7 +198,7 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations(
- self, event: EventBase, rules: Iterable[Dict[str, Any]]
+ self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
@@ -216,12 +228,11 @@ class BulkPushRuleEvaluator:
# Pre-filter to figure out which relation types are interesting.
rel_types = set()
- for rule in rules:
- # Skip disabled rules.
- if "enabled" in rule and not rule["enabled"]:
+ for rule, enabled in rules:
+ if not enabled:
continue
- for condition in rule["conditions"]:
+ for condition in rule.conditions:
if condition["kind"] != "org.matrix.msc3772.relation_match":
continue
@@ -254,7 +265,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event)
- actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+ actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
@@ -317,15 +328,13 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
- for rule in rules:
- if "enabled" in rule and not rule["enabled"]:
+ for rule, enabled in rules:
+ if not enabled:
continue
- matches = evaluator.check_conditions(
- rule["conditions"], uid, display_name
- )
+ matches = evaluator.check_conditions(rule.conditions, uid, display_name)
if matches:
- actions = [x for x in rule["actions"] if x != "dont_notify"]
+ actions = [x for x in rule.actions if x != "dont_notify"]
if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 5117ef685..73618d923 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -18,16 +18,15 @@ from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.types import UserID
+from .baserules import FilteredPushRules, PushRule
+
def format_push_rules_for_user(
- user: UserID, ruleslist: List
+ user: UserID, ruleslist: FilteredPushRules
) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
- # We're going to be mutating this a lot, so do a deep copy
- ruleslist = copy.deepcopy(ruleslist)
-
rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
"global": {},
"device": {},
@@ -35,11 +34,30 @@ def format_push_rules_for_user(
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
- for r in ruleslist:
- template_name = _priority_class_to_template_name(r["priority_class"])
+ for r, enabled in ruleslist:
+ template_name = _priority_class_to_template_name(r.priority_class)
+
+ rulearray = rules["global"][template_name]
+
+ template_rule = _rule_to_template(r)
+ if not template_rule:
+ continue
+
+ rulearray.append(template_rule)
+
+ template_rule["enabled"] = enabled
+
+ if "conditions" not in template_rule:
+ # Not all formatted rules have explicit conditions, e.g. "room"
+ # rules omit them as they can be derived from the kind and rule ID.
+ #
+ # If the formatted rule has no conditions then we can skip the
+ # formatting of conditions.
+ continue
# Remove internal stuff.
- for c in r["conditions"]:
+ template_rule["conditions"] = copy.deepcopy(template_rule["conditions"])
+ for c in template_rule["conditions"]:
c.pop("_cache_key", None)
pattern_type = c.pop("pattern_type", None)
@@ -52,16 +70,6 @@ def format_push_rules_for_user(
if sender_type == "user_id":
c["sender"] = user.to_string()
- rulearray = rules["global"][template_name]
-
- template_rule = _rule_to_template(r)
- if template_rule:
- if "enabled" in r:
- template_rule["enabled"] = r["enabled"]
- else:
- template_rule["enabled"] = True
- rulearray.append(template_rule)
-
return rules
@@ -71,24 +79,24 @@ def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
return d
-def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- unscoped_rule_id = None
- if "rule_id" in rule:
- unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
+def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
+ templaterule: Dict[str, Any]
- template_name = _priority_class_to_template_name(rule["priority_class"])
+ unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
+
+ template_name = _priority_class_to_template_name(rule.priority_class)
if template_name in ["override", "underride"]:
- templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+ templaterule = {"conditions": rule.conditions, "actions": rule.actions}
elif template_name in ["sender", "room"]:
- templaterule = {"actions": rule["actions"]}
- unscoped_rule_id = rule["conditions"][0]["pattern"]
+ templaterule = {"actions": rule.actions}
+ unscoped_rule_id = rule.conditions[0]["pattern"]
elif template_name == "content":
- if len(rule["conditions"]) != 1:
+ if len(rule.conditions) != 1:
return None
- thecond = rule["conditions"][0]
+ thecond = rule.conditions[0]
if "pattern" not in thecond:
return None
- templaterule = {"actions": rule["actions"]}
+ templaterule = {"actions": rule.actions}
templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
@@ -97,8 +105,8 @@ def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
- if "default" in rule:
- templaterule["default"] = rule["default"]
+ if rule.default:
+ templaterule["default"] = True
return templaterule
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2e8a017ad..3c5632cd9 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,18 @@
import logging
import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
+from typing import (
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Pattern,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
from matrix_common.regex import glob_to_regex, to_word_pattern
@@ -32,14 +43,14 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(
- ev: EventBase, condition: Dict[str, Any], room_member_count: int
+ ev: EventBase, condition: Mapping[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(
ev: EventBase,
- condition: Dict[str, Any],
+ condition: Mapping[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
@@ -54,7 +65,7 @@ def _sender_notification_permission(
return sender_power_level >= room_notif_level
-def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
+def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -137,7 +148,7 @@ class PushRuleEvaluatorForEvent:
self._condition_cache: Dict[str, bool] = {}
def check_conditions(
- self, conditions: List[dict], uid: str, display_name: Optional[str]
+ self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's conditions/user ID/display name match the event.
@@ -169,7 +180,7 @@ class PushRuleEvaluatorForEvent:
return True
def matches(
- self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
+ self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's condition/user ID/display name match the event.
@@ -204,7 +215,7 @@ class PushRuleEvaluatorForEvent:
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True
- def _event_match(self, condition: dict, user_id: str) -> bool:
+ def _event_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "event_match" push rule condition.
@@ -269,7 +280,7 @@ class PushRuleEvaluatorForEvent:
return bool(r.search(body))
- def _relation_match(self, condition: dict, user_id: str) -> bool:
+ def _relation_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "relation_match" push rule condition.
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
index b751359bd..bd4f7cea9 100644
--- a/synapse/res/templates/account_previously_renewed.html
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -1 +1,12 @@
-Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
+
+
+
+
+ Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
+ Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
\ No newline at end of file
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index e8c0f52f0..57b319f37 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1,12 @@
-Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
+
+
+
+
+ Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
+ Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+
+
\ No newline at end of file
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
index cc4ab07e0..71f2215b7 100644
--- a/synapse/res/templates/add_threepid.html
+++ b/synapse/res/templates/add_threepid.html
@@ -1,9 +1,14 @@
-
+
+
+
+
+
+
+ Request to add an email address to your Matrix account
+
A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:
-
{{ link }}
-
If this was not you, you can safely ignore this email. Thank you.
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
index 441d11c84..bd627ee9c 100644
--- a/synapse/res/templates/add_threepid_failure.html
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -1,8 +1,13 @@
-
-
+
+
+
+
+
+
+ Request failed
+
-The request failed for the following reason: {{ failure_reason }}.
-
-No changes have been made to your account.
+ The request failed for the following reason: {{ failure_reason }}.
+ No changes have been made to your account.
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
index fbd6e4018..49170c138 100644
--- a/synapse/res/templates/add_threepid_success.html
+++ b/synapse/res/templates/add_threepid_success.html
@@ -1,6 +1,12 @@
-
-
+
+
+
+
+
+
+ Your email has now been validated
+
-Your email has now been validated, please return to your client. You may now close this window.
+ Your email has now been validated, please return to your client. You may now close this window.
-
+
\ No newline at end of file
diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html
index baf463314..2d6ac44a0 100644
--- a/synapse/res/templates/auth_success.html
+++ b/synapse/res/templates/auth_success.html
@@ -1,8 +1,8 @@
Success!
-
+
+
diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html
index 16730a527..20e831ff4 100644
--- a/synapse/res/templates/registration.html
+++ b/synapse/res/templates/registration.html
@@ -1,4 +1,9 @@
-
+
+
+ Registration
+
+
+
You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:
diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html
index 2833d79c3..a6ed22bc9 100644
--- a/synapse/res/templates/registration_failure.html
+++ b/synapse/res/templates/registration_failure.html
@@ -1,5 +1,8 @@
-
-
+
+
+
+
+
Validation failed for the following reason: {{ failure_reason }}.
diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html
index fbd6e4018..d51d5549d 100644
--- a/synapse/res/templates/registration_success.html
+++ b/synapse/res/templates/registration_success.html
@@ -1,5 +1,9 @@
-
-
+
+
+ Your email has now been validated
+
+
+
Your email has now been validated, please return to your client. You may now close this window.
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
index 4577ce170..59a98f564 100644
--- a/synapse/res/templates/registration_token.html
+++ b/synapse/res/templates/registration_token.html
@@ -1,8 +1,8 @@
-
+
Authentication
-
+
+
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index c3e4deed9..075f801ce 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -3,8 +3,8 @@
SSO account deactivated
-
-
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index cf72df0a2..2d1db386e 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -3,7 +3,8 @@
Create your account
-
+
+
diff --git a/synapse/static/client/register/index.html b/synapse/static/client/register/index.html
index 140653574..27bbd76f5 100644
--- a/synapse/static/client/register/index.html
+++ b/synapse/static/client/register/index.html
@@ -2,7 +2,8 @@
Registration
-
+
+
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab4..dad3731b9 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+ SynapseTags,
+ active_span,
+ set_tag,
+ start_active_span_follows_from,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue.append(end_item)
# also add our active opentracing span to the item so that we get a link back
- span = opentracing.active_span()
+ span = active_span()
if span:
end_item.parent_opentracing_span_contexts.append(span.context)
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
res = await make_deferred_yieldable(end_item.deferred.observe())
# add another opentracing span which links to the persist trace.
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
f"Found an unexpected task type in event persistence queue: {task}"
)
- @opentracing.trace
+ @trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ event_ids: List[str] = []
partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
+ event_ids.append(event.event_id)
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
self.main_store.get_room_max_token(),
)
- @opentracing.trace
+ @trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0d480f101..f9ffd0e29 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,7 +29,8 @@ from typing import (
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -228,10 +229,12 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids}
@trace
+ @tag_args
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -240,6 +243,9 @@ class StateStorageController:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -248,8 +254,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
+ # Full state is not required if the state filter is restrictive enough.
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -292,7 +302,10 @@ class StateStorageController:
@trace
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -300,6 +313,9 @@ class StateStorageController:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -309,7 +325,9 @@ class StateStorageController:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -332,6 +350,7 @@ class StateStorageController:
)
@trace
+ @tag_args
async def get_state_group_for_events(
self,
event_ids: Collection[str],
@@ -473,6 +492,7 @@ class StateStorageController:
prev_stream_id, max_stream_id
)
+ @trace
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
@@ -506,3 +526,15 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_users_in_room_with_profiles(
+ self, room_id: str
+ ) -> Dict[str, ProfileInfo]:
+ """
+ Get the current users in the room with their profiles.
+ If the room is currently partial-stated, this will block until the room has
+ full state.
+ """
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_users_in_room_with_profiles(room_id)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b647..c836078da 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
+ @trace
+ @tag_args
async def get_auth_chain_ids(
self,
room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
+ @trace
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse()
return event_results
+ @trace
+ @tag_args
async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given event as a prev event
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ @trace
async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
await self.db_pool.simple_upsert(
table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 161aad0f8..eabf9c973 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,7 +74,17 @@ receipt.
"""
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
@@ -154,7 +164,9 @@ class NotifCounts:
highlight_count: int = 0
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+ actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -227,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
- for a given user in a given room after the given read receipt.
+ for a given user in a given room after their latest read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
@@ -238,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A dict containing the counts mentioned earlier in this docstring,
- respectively under the keys "notify_count", "highlight_count" and
- "unread_count".
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -255,6 +266,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
) -> NotifCounts:
+ # Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
txn,
user_id,
@@ -266,13 +278,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
),
)
- stream_ordering = None
if result:
_, stream_ordering = result
- if stream_ordering is None:
- # Either last_read_event_id is None, or it's an event we don't have (e.g.
- # because it's been purged), in which case retrieve the stream ordering for
+ else:
+ # If the user has no receipts in the room, retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -289,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
def _get_unread_counts_by_pos_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ receipt_stream_ordering: int,
) -> NotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ receipt_stream_ordering: The stream ordering of the user's latest
+ receipt in the room. If there are no receipts, the stream ordering
+ of the user's join event.
+
+ Returns
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
counts = NotifCounts()
@@ -320,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
OR last_receipt_stream_ordering = ?
)
""",
- (room_id, user_id, stream_ordering, stream_ordering),
+ (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
@@ -338,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND stream_ordering > ?
AND highlight = 1
"""
- txn.execute(sql, (user_id, room_id, stream_ordering))
+ txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
# Finally we need to count push actions that aren't included in the
- # summary returned above, e.g. recent events that haven't been
- # summarised yet, or the summary is empty due to a recent read receipt.
- stream_ordering = max(stream_ordering, summary_stream_ordering)
+ # summary returned above. This might be due to recent events that haven't
+ # been summarised yet or the summary is out of date due to a recent read
+ # receipt.
+ start_unread_stream_ordering = max(
+ receipt_stream_ordering, summary_stream_ordering
+ )
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering
+ txn, room_id, user_id, start_unread_stream_ordering
)
counts.notify_count += notify_count
@@ -733,7 +762,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def add_push_actions_to_staging(
self,
event_id: str,
- user_id_actions: Dict[str, List[Union[dict, str]]],
+ user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -750,7 +779,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
- user_id: str, actions: List[Union[dict, str]]
+ user_id: str, actions: Collection[Union[Mapping, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
@@ -1151,8 +1180,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: The database transaction.
old_rotate_stream_ordering: The previous maximum event stream ordering.
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
-
- Returns whether the archiving process has caught up or not.
"""
# Calculate the new counts that should be upserted into event_push_summary
@@ -1238,9 +1265,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(rotate_to_stream_ordering,),
)
- async def _remove_old_push_actions_that_have_rotated(
- self,
- ) -> None:
+ async def _remove_old_push_actions_that_have_rotated(self) -> None:
"""Clear out old push actions that have been summarised."""
# We want to clear out anything that is older than a day that *has* already
@@ -1397,7 +1422,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
]
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
for action in actions:
if not isinstance(action, dict):
continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a4..a4010ee28 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ @trace
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e9ff6cfb3..8a7cdb024 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.logging.opentracing import start_active_span, tag_args, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
+ @trace
+ @tag_args
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
"""
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
- events_to_fetch = event_ids
- while events_to_fetch:
- row_map = await self._enqueue_events(events_to_fetch)
+ async def _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch: Collection[str],
+ ) -> Collection[str]:
+ """
+ Fetch all of the given event_ids and return any associated redaction event_ids
+ that we still need to fetch in the next iteration.
+ """
+ row_map = await self._enqueue_events(event_ids_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids: Set[str] = set()
- for event_id in events_to_fetch:
+ for event_id in event_ids_to_fetch:
row = row_map.get(event_id)
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
- if events_to_fetch:
- logger.debug("Also fetching redaction events %s", events_to_fetch)
+ event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+ return event_ids_to_fetch
+
+ # Grab the initial list of events requested
+ event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids
+ )
+ # Then go and recursively find all of the associated redactions
+ with start_active_span("recursively fetching redactions"):
+ while event_ids_to_fetch:
+ logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+ event_ids_to_fetch = (
+ await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch
+ )
+ )
# build a map from event_id to EventBase
event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
+ @trace
+ @tag_args
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
@@ -2200,3 +2224,63 @@ class EventsWorkerStore(SQLBaseStore):
(room_id,),
)
return [row[0] for row in txn]
+
+ def mark_event_rejected_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ rejection_reason: Optional[str],
+ ) -> None:
+ """Mark an event that was previously accepted as rejected, or vice versa
+
+ This can happen, for example, when resyncing state during a faster join.
+
+ Args:
+ txn:
+ event_id: ID of event to update
+ rejection_reason: reason it has been rejected, or None if it is now accepted
+ """
+ if rejection_reason is None:
+ logger.info(
+ "Marking previously-processed event %s as accepted",
+ event_id,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ "rejections",
+ keyvalues={"event_id": event_id},
+ )
+ else:
+ logger.info(
+ "Marking previously-processed event %s as rejected(%s)",
+ event_id,
+ rejection_reason,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": event_id},
+ values={
+ "reason": rejection_reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
+ self.db_pool.simple_update_txn(
+ txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"rejection_reason": rejection_reason},
+ )
+
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+ # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+ # call '_send_invalidation_to_replication', but we actually need the other
+ # end to call _invalidate_local_get_event_cache() rather than (just)
+ # _get_event_cache.invalidate().
+ #
+ # One solution might be to (somehow) get the workers to call
+ # _invalidate_caches_for_event() (though that will invalidate more than
+ # strictly necessary).
+ #
+ # https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16..255620f99 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
+ ruleslist = [
+ PushRule(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=db_to_json(rawrule["conditions"]),
+ actions=db_to_json(rawrule["actions"]),
+ )
+ for rawrule in rawrules
]
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ push_rules = compile_push_rules(ruleslist)
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
-
- return rules
+ return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -216,11 +198,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -234,11 +216,13 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
@@ -345,8 +329,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -817,7 +801,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -827,21 +811,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +849,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules:
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f22..124c70ad3 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: The receipt types to fetch.
Returns:
- The latest receipt, if one exists.
+ The event ID and stream ordering of the latest receipt, if one exists.
"""
clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7..7fb9c801d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
"""
user_id: str
+ token_id: int
is_guest: bool = False
shadow_banned: bool = False
- token_id: Optional[int] = None
device_id: Optional[str] = None
valid_until_ms: Optional[int] = None
token_owner: str = attr.ib()
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0f1f0d11e..b7d4baa6b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 93ff4816c..827c1f1ef 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -283,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -1212,6 +1215,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+ # `count(*)` returns always an integer
+ # If any rows still exist it means someone has not forgotten this room yet
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index f70705a0a..0b10af0e5 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -430,6 +430,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
updatevalues={"state_group": state_group},
)
+ # the event may now be rejected where it was not before, or vice versa,
+ # in which case we need to update the rejected flags.
+ if bool(context.rejected) != (event.rejected_reason is not None):
+ self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
self.db_pool.simple_delete_one_txn(
txn,
table="partial_state_events",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index af3bab2c1..0004d955b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,15 +539,6 @@ class StateFilter:
is_mine_id: a callable which confirms if a given state_key matches a mxid
of a local user
"""
-
- # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
- # there may be circumstances in which we return a piece of state that, once we
- # resync the state, we discover is invalid. For example: if it turns out that
- # the sender of a piece of state wasn't actually in the room, then clearly that
- # state shouldn't have been returned.
- # We should at least add some tests around this to see what happens.
- # https://github.com/matrix-org/synapse/issues/13006
-
# if we haven't requested membership events, then it depends on the value of
# 'include_others'
if EventTypes.Member not in self.types:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f..b4bf49dac 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialStateEventsTracker.await_full_state")
async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialCurrentStateTracker.await_full_state")
async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room.
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 6394cc39a..f678b52cb 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,8 @@ import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
+from prometheus_client.core import Counter
+
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
@@ -27,6 +29,8 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock
if typing.TYPE_CHECKING:
@@ -35,6 +39,32 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter("synapse_rate_limit_sleep", "")
+rate_limit_reject_counter = Counter("synapse_rate_limit_reject", "")
+queue_wait_timer = Histogram(
+ "synapse_rate_limit_queue_wait_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ "+Inf",
+ ),
+)
+
+
class FederationRateLimiter:
def __init__(self, clock: Clock, config: FederationRatelimitSettings):
def new_limiter() -> "_PerHostRatelimiter":
@@ -44,6 +74,27 @@ class FederationRateLimiter:
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
+ # We track the number of affected hosts per time-period so we can
+ # differentiate one really noisy homeserver from a general
+ # ratelimit tuning problem across the federation.
+ LaterGauge(
+ "synapse_rate_limit_sleep_affected_hosts",
+ "Number of hosts that had requests put to sleep",
+ [],
+ lambda: sum(
+ ratelimiter.should_sleep() for ratelimiter in self.ratelimiters.values()
+ ),
+ )
+ LaterGauge(
+ "synapse_rate_limit_reject_affected_hosts",
+ "Number of hosts that had requests rejected",
+ [],
+ lambda: sum(
+ ratelimiter.should_reject()
+ for ratelimiter in self.ratelimiters.values()
+ ),
+ )
+
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
@@ -59,7 +110,7 @@ class FederationRateLimiter:
Returns:
context manager which returns a deferred.
"""
- return self.ratelimiters[host].ratelimit()
+ return self.ratelimiters[host].ratelimit(host)
class _PerHostRatelimiter:
@@ -94,19 +145,42 @@ class _PerHostRatelimiter:
self.request_times: List[int] = []
@contextlib.contextmanager
- def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+ def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
+ self.host = host
+
request_id = object()
- ret = self._on_enter(request_id)
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
try:
yield ret
finally:
self._on_exit(request_id)
+ def should_reject(self) -> bool:
+ """
+ Whether to reject the request if we already have too many queued up
+ (either sleeping or in the ready queue).
+ """
+ queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+ return queue_size > self.reject_limit
+
+ def should_sleep(self) -> bool:
+ """
+ Whether to sleep the request if we already have too many requests coming
+ through within the window.
+ """
+ return len(self.request_times) > self.sleep_limit
+
+ async def _on_enter_with_tracing(self, request_id: object) -> None:
+ with start_active_span("ratelimit wait"), queue_wait_timer.time():
+ await self._on_enter(request_id)
+
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
@@ -117,8 +191,9 @@ class _PerHostRatelimiter:
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
- queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
- if queue_size > self.reject_limit:
+ if self.should_reject():
+ logger.debug("Ratelimiter(%s): rejecting request", self.host)
+ rate_limit_reject_counter.inc()
raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit)
)
@@ -130,7 +205,8 @@ class _PerHostRatelimiter:
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
- "Ratelimiter: queueing request (queue now %i items)",
+ "Ratelimiter(%s): queueing request (queue now %i items)",
+ self.host,
len(self.ready_request_queue),
)
@@ -139,19 +215,28 @@ class _PerHostRatelimiter:
return defer.succeed(None)
logger.debug(
- "Ratelimit [%s]: len(self.request_times)=%d",
+ "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+ self.host,
id(request_id),
len(self.request_times),
)
- if len(self.request_times) > self.sleep_limit:
- logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+ if self.should_sleep():
+ logger.debug(
+ "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+ self.host,
+ id(request_id),
+ self.sleep_sec,
+ )
+ rate_limit_sleep_counter.inc()
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
- logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+ )
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -161,7 +246,9 @@ class _PerHostRatelimiter:
ret_defer = queue_request()
def on_start(r: object) -> object:
- logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+ )
self.current_processing.add(request_id)
return r
@@ -183,7 +270,7 @@ class _PerHostRatelimiter:
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id: object) -> None:
- logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+ logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 741c16575..ffcdf996f 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -73,8 +73,8 @@ async def filter_events_for_client(
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
- always_include_ids: set of event ids to specifically
- include (unless sender is ignored)
+ always_include_ids: set of event ids to specifically include, if present
+ in events (unless sender is ignored)
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dfcfaf79b..e0f363555 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
serialized = macaroon.serialize()
user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
- self.assertEqual(user_id, user_info.user_id)
+ self.assertEqual(user_id, user_info.user.to_string())
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index ffc3012a8..685a9a6d5 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
load_legacy_presence_router(hs)
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index ff9f2e8ed..7b9b71152 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes
+from synapse.push.baserules import PushRule
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
@@ -130,12 +130,12 @@ class DeactivateAccountTestCase(HomeserverTestCase):
),
)
- def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+ def _is_custom_rule(self, push_rule: PushRule) -> bool:
"""
Default rules start with a dot: such as .m.rule and .im.vector.
This function returns true iff a rule is custom (not default).
"""
- return "/." not in push_rule["rule_id"]
+ return "/." not in push_rule.rule_id
def test_push_rules_deleted_upon_account_deactivation(self) -> None:
"""
@@ -157,22 +157,21 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
# Test the rule exists
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
# Check our rule made it
self.assertEqual(
push_rules,
[
- {
- "user_name": "@user:test",
- "rule_id": "personal.override.rule1",
- "priority_class": 5,
- "priority": 0,
- "conditions": [],
- "actions": [],
- "default": False,
- }
+ PushRule(
+ rule_id="personal.override.rule1",
+ priority_class=5,
+ conditions=[],
+ actions=[],
+ )
],
push_rules,
)
@@ -180,9 +179,11 @@ class DeactivateAccountTestCase(HomeserverTestCase):
# Request the deactivation of our account
self._deactivate_my_account()
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)
@@ -321,3 +322,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
),
)
+
+ def test_deactivate_account_needs_auth(self) -> None:
+ """
+ Tests that making a request to /deactivate with an empty body
+ succeeds in starting the user-interactive auth flow.
+ """
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {},
+ access_token=self.token,
+ )
+
+ self.assertEqual(req.code, 401, req)
+ self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4c62449c8..75934b170 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,7 +21,6 @@ from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
- load_legacy_password_auth_providers(hs)
-
- return hs
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 23f35d5bf..86b3d5197 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
- load_legacy_spam_checkers(hs)
-
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index b4e1405ae..1d13ed1e8 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -14,7 +14,7 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -216,7 +216,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
# - trying to remote-join again.
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af133312..8adba29d7 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id: str, user_id: str) -> None:
- if user_id not in [u.to_string() for u in self.room_members]:
+ async def check_user_in_room(room_id: str, requester: Requester) -> None:
+ if requester.user.to_string() not in [
+ u.to_string() for u in self.room_members
+ ]:
raise AuthError(401, "User is not in the room")
return None
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 3b14c76d7..0917e478a 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -25,6 +25,8 @@ from synapse.logging.context import (
from synapse.logging.opentracing import (
start_active_span,
start_active_span_follows_from,
+ tag_args,
+ trace_with_opname,
)
from synapse.util import Clock
@@ -38,8 +40,12 @@ try:
except ImportError:
jaeger_client = None # type: ignore
+import logging
+
from tests.unittest import TestCase
+logger = logging.getLogger(__name__)
+
class LogContextScopeManagerTestCase(TestCase):
"""
@@ -194,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
self._reporter.get_spans(),
[scopes[1].span, scopes[2].span, scopes[0].span],
)
+
+ def test_trace_decorator_sync(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with sync functions
+ """
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_sync_func", tracer=self._tracer)
+ @tag_args
+ def fixture_sync_func() -> str:
+ return "foo"
+
+ result = fixture_sync_func()
+ self.assertEqual(result, "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_sync_func"],
+ )
+
+ def test_trace_decorator_deferred(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with functions that return deferreds
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
+ @tag_args
+ def fixture_deferred_func() -> "defer.Deferred[str]":
+ d1: defer.Deferred[str] = defer.Deferred()
+ d1.callback("foo")
+ return d1
+
+ result_d1 = fixture_deferred_func()
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(result_d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_deferred_func"],
+ )
+
+ def test_trace_decorator_async(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with async functions
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_async_func", tracer=self._tracer)
+ @tag_args
+ async def fixture_async_func() -> str:
+ return "foo"
+
+ d1 = defer.ensureDeferred(fixture_async_func())
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_async_func"],
+ )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 106159fa6..02cef6f87 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -30,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
class ModuleApiTestCase(HomeserverTestCase):
@@ -738,11 +737,6 @@ class ModuleApiTestCase(HomeserverTestCase):
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
- # Testing stream ID replication from the main to worker processes requires postgres
- # (due to needing `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -752,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533..ce53f808d 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
- ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+ ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
+ Enables Redis, providing a fake Redis server.
+
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
+
def setUp(self):
super().setUp()
# build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
- if self.hs.config.redis.redis_enabled:
- # Handle attempts to connect to fake redis server.
- self.reactor.add_tcp_client_callback(
- "localhost",
- 6379,
- self.connect_any_redis_attempts,
- )
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
+ )
- self.hs.get_replication_command_handler().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
- # Set up TCP replication between master and the new worker if we don't
- # have Redis support enabled.
- if not worker_hs.config.redis.redis_enabled:
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs,
- "client",
- "test",
- self.clock,
- repl_handler,
- )
- server = self.server_factory.buildProtocol(
- IPv4Address("TCP", "127.0.0.1", 0)
- )
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
- if worker_hs.config.redis.redis_enabled:
- worker_hs.get_replication_command_handler().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
- """
- A test case that enables Redis, providing a fake Redis server.
- """
-
- if not hiredis:
- skip = "Requires hiredis"
-
- if not USE_POSTGRES_FOR_TESTS:
- # Redis replication only takes place on Postgres
- skip = "Requires Postgres"
-
- def default_config(self) -> Dict[str, Any]:
- """
- Overrides the default config to enable Redis.
- Even if the test only uses make_worker_hs, the main process needs Redis
- enabled otherwise it won't create a Fake Redis server to listen on the
- Redis port and accept fake TCP connections.
- """
- base = super().default_config()
- base["redis"] = {"enabled": True}
- return base
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd..1e299d2d6 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069..541d39028 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
- # Event persister sharding requires postgres (due to needing
- # `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 06e74d5e5..a8f643683 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,7 +13,6 @@
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
+ "Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 8295ecf24..d507a3af8 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import Collection
from parameterized import parameterized
@@ -51,7 +50,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -64,7 +63,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# job_name invalid
@@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def _register_bg_update(self) -> None:
@@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 779f1bfac..d52aee8f9 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -58,7 +57,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(method, self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -76,7 +75,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -85,7 +84,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -98,13 +97,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -117,12 +116,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self) -> None:
"""
- Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or 200.
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -134,7 +133,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -179,7 +178,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -312,7 +311,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -331,7 +330,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -339,7 +338,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -348,12 +347,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -363,7 +362,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self) -> None:
@@ -438,7 +437,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -457,7 +456,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -465,7 +464,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -474,12 +473,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -489,7 +488,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self) -> None:
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 9bc6ce62c..8a4e5c3f7 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -81,16 +80,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_invalid_search_order(self) -> None:
"""
- Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
+ Testing that a invalid search order returns a 400
"""
channel = self.make_request(
@@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self) -> None:
"""
- Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative limit parameter returns a 400
"""
channel = self.make_request(
@@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self) -> None:
"""
- Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative from parameter returns a 400
"""
channel = self.make_request(
@@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -431,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("score", c)
self.assertIn("reason", c)
+ def test_count_correct_despite_table_deletions(self) -> None:
+ """
+ Tests that the count matches the number of rows, even if rows in joined tables
+ are missing.
+ """
+
+ # Delete rows from room_stats_state for one of our rooms.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_delete(
+ "room_stats_state", {"room_id": self.room_id1}, desc="_"
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ # The 'total' field is 10 because only 10 reports will actually
+ # be retrievable since we deleted the rows in the room_stats_state
+ # table.
+ self.assertEqual(channel.json_body["total"], 10)
+ # This is consistent with the number of rows actually returned.
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+
class EventReportDetailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -466,16 +472,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -484,11 +486,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -507,7 +505,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_invalid_report_id(self) -> None:
"""
- Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
+ Testing that an invalid `report_id` returns a 400.
"""
# `report_id` is negative
@@ -517,11 +515,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -535,11 +529,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -553,11 +543,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -566,7 +552,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_report_id_not_found(self) -> None:
"""
- Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
+ Testing that a not existing `report_id` returns a 404.
"""
channel = self.make_request(
@@ -575,11 +561,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c3927c273..4c7864c62 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -64,7 +63,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -117,7 +116,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
# invalid destination
@@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"The retry timing does not need to be reset for this destination.",
channel.json_body["error"],
@@ -561,7 +560,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -604,7 +603,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 92fd6c780..aadb31ca8 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from http import HTTPStatus
from parameterized import parameterized
@@ -60,7 +59,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -81,16 +80,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self) -> None:
"""
- Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a media that does not exist returns a 404
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -100,12 +95,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a media that is not a local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self) -> None:
@@ -188,10 +183,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -230,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -250,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for media that is not local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -269,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self) -> None:
@@ -282,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
@@ -302,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -319,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -337,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -354,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
@@ -612,10 +579,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -667,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
@@ -688,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self) -> None:
@@ -800,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
@@ -821,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self) -> None:
@@ -894,7 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -913,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -930,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -947,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 544daaa4c..8f8abc21c 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -13,7 +13,6 @@
# limitations under the License.
import random
import string
-from http import HTTPStatus
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_no_auth(self) -> None:
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self) -> None:
@@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self) -> None:
@@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self) -> None:
@@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self) -> None:
@@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
+ self.assertEqual(400, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self) -> None:
@@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.BAD_REQUEST,
+ 400,
channel.code,
msg=channel.json_body,
)
@@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self) -> None:
@@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self) -> None:
@@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
@@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self) -> None:
@@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self) -> None:
@@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self) -> None:
@@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self) -> None:
@@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
@@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self) -> None:
@@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
@@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self) -> None:
@@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self) -> None:
@@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self) -> None:
@@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self) -> None:
@@ -682,7 +589,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self) -> None:
@@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_list_no_auth(self) -> None:
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self) -> None:
@@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self) -> None:
@@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def _test_list_query_parameter(self, valid: str) -> None:
"""Helper used to test both valid=true and valid=false."""
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 6ea7858db..fd6da557c 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
@@ -68,7 +67,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -78,7 +77,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -98,7 +97,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -145,7 +144,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -163,7 +162,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -178,7 +177,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self) -> None:
@@ -319,7 +318,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -398,7 +397,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -494,7 +493,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -504,7 +503,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
channel = self.make_request(
@@ -556,7 +555,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -592,7 +591,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -610,7 +609,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -625,7 +624,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_delete_expired_status(self) -> None:
@@ -696,7 +695,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_delete_same_room_twice(self) -> None:
@@ -722,9 +721,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
- )
+ self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
self.assertEqual(
f"History purge already in progress for {self.room_id}",
@@ -858,7 +855,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -955,7 +952,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1546,7 +1543,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
+ _search_test(None, "", expected_http_code=400)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1636,6 +1633,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("history_visibility", channel.json_body)
self.assertIn("state_events", channel.json_body)
self.assertIn("room_type", channel.json_body)
+ self.assertIn("forgotten", channel.json_body)
self.assertEqual(room_id_1, channel.json_body["room_id"])
def test_single_room_devices(self) -> None:
@@ -1782,7 +1780,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# delete the rooms and get joined roomed membership
url = f"/_matrix/client/r0/rooms/{room_id}/joined_members"
channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1811,7 +1809,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -1821,7 +1819,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1836,12 +1834,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1851,7 +1849,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self) -> None:
@@ -1866,7 +1864,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1874,7 +1872,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self) -> None:
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return error 404.
"""
url = "/_synapse/admin/v1/join/!unknown:test"
@@ -1885,7 +1883,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(
"Can't join remote room because no servers that are in the room have been provided.",
channel.json_body["error"],
@@ -1893,7 +1891,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/join/invalidroom"
@@ -1904,7 +1902,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1952,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self) -> None:
@@ -2067,7 +2065,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self) -> None:
@@ -2243,11 +2241,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+ # We expect this to fail with a 400 as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2277,7 +2275,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str) -> None:
- """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+ """If the user is not a server admin, an error 403 is returned."""
channel = self.make_request(
method,
@@ -2286,12 +2284,12 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str) -> None:
- """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+ """Check that invalid room names, return an error 400."""
channel = self.make_request(
method,
@@ -2300,7 +2298,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -2317,7 +2315,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# `block` is not set
@@ -2328,7 +2326,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no content is send
@@ -2338,7 +2336,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
def test_block_room(self) -> None:
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index bea3ac34d..a2f347f66 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -57,7 +56,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url)
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -72,7 +71,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -80,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
channel = self.make_request(
"POST",
self.url,
@@ -88,13 +87,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
channel = self.make_request(
"POST",
@@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -156,10 +155,66 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_avatar_url": "somthingwrong",
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_invalid_avatar_url(self) -> None:
+ """If avatar url in homeserver.yaml is invalid and
+ "check avatar size and mime type" is set, an error is returned.
+ TODO: Should be checked when reading the configuration."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+
+ self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_displayname_is_set_avatar_is_none(self) -> None:
+ """
+ Tests that sending a server notices is successfully,
+ if a display_name is set, avatar_url is `None` and
+ "check avatar size and mime type" is set.
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ self._check_invite_and_join_status(self.other_user, 1, 0)
+
def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
@@ -172,7 +227,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index baed27a81..b60f16b91 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -51,16 +50,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
"GET",
@@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index c2b54b1ef..1afd08270 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -79,7 +78,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -111,7 +110,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -119,7 +118,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self) -> None:
@@ -142,7 +141,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self) -> None:
@@ -198,7 +197,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self) -> None:
@@ -219,7 +218,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -229,28 +228,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -261,28 +260,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -298,7 +297,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
@@ -375,7 +374,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -466,7 +465,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -478,7 +477,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self) -> None:
@@ -591,7 +590,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -601,7 +600,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -611,7 +610,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid deactivated
@@ -621,7 +620,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -631,7 +630,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -641,7 +640,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -905,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+ """
+ Tests user device management-related Admin APIs.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Set up an Admin user to query the Admin API with.
+ self.admin_user_id = self.register_user("admin", "pass", admin=True)
+ self.admin_user_token = self.login("admin", "pass")
+
+ # Set up a test user to query the devices of.
+ self.other_user_device_id = "TESTDEVICEID"
+ self.other_user_device_display_name = "My Test Device"
+ self.other_user_client_ip = "1.2.3.4"
+ self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+ self.other_user_id = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login(
+ "user",
+ "pass",
+ device_id=self.other_user_device_id,
+ additional_request_fields={
+ "initial_device_display_name": self.other_user_device_display_name,
+ },
+ )
+
+ # Have the "other user" make a request so that the "last_seen_*" fields are
+ # populated in the tests below.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ client_ip=self.other_user_client_ip,
+ custom_headers=[
+ ("User-Agent", self.other_user_user_agent),
+ ],
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_list_user_devices(self) -> None:
+ """
+ Tests that a user's devices and attributes are listed correctly via the Admin API.
+ """
+ # Request all devices of "other user"
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Double-check we got the single device expected
+ user_devices = channel.json_body["devices"]
+ self.assertEqual(len(user_devices), 1)
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(user_devices[0])
+
+ # Request just a single device for "other user" by its ID
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+ f"{self.other_user_device_id}",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(channel.json_body)
+
+ def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+ # Check that all device expected attributes are present
+ self.assertEqual(response["user_id"], self.other_user_id)
+ self.assertEqual(response["device_id"], self.other_user_device_id)
+ self.assertEqual(response["display_name"], self.other_user_device_display_name)
+ self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+ self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+ self.assertIsInstance(response["last_seen_ts"], int)
+ self.assertGreater(response["last_seen_ts"], 0)
+
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -941,7 +1030,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -952,7 +1041,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -962,12 +1051,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that deactivation for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -976,7 +1065,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self) -> None:
@@ -991,18 +1080,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that deactivation for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self) -> None:
@@ -1220,7 +1309,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1230,12 +1319,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1244,7 +1333,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1259,7 +1348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1269,7 +1358,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1279,7 +1368,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1289,7 +1378,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1299,7 +1388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1311,7 +1400,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1320,7 +1409,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1330,7 +1419,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1339,7 +1428,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self) -> None:
@@ -1379,7 +1468,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1523,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1512,7 +1601,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1550,7 +1639,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1585,7 +1674,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1626,7 +1715,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1666,7 +1755,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@@ -2064,7 +2153,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
+ self.assertEqual(409, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2228,7 +2317,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2261,7 +2350,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2295,7 +2384,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2407,7 +2496,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2431,7 +2520,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2520,7 +2609,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2535,7 +2624,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -2678,7 +2767,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2693,12 +2782,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2707,12 +2796,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2722,7 +2811,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self) -> None:
@@ -2808,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2822,12 +2911,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2835,12 +2924,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
- """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2849,7 +2938,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self) -> None:
@@ -2970,7 +3059,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -2980,7 +3069,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -2990,7 +3079,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -3000,7 +3089,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -3393,7 +3482,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self) -> None:
@@ -3402,7 +3491,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self) -> None:
"""Test that sending event as a user works."""
@@ -3447,7 +3536,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3480,7 +3569,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self) -> None:
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3501,7 +3590,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3538,7 +3627,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
"com.example.test",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Login in as the user
@@ -3559,7 +3648,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
user=self.other_user,
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Logging in as the other user and joining a room should work, even
@@ -3594,7 +3683,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -3609,12 +3698,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined]
@@ -3623,7 +3712,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self) -> None:
@@ -3680,7 +3769,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3691,18 +3780,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self) -> None:
"""
@@ -3762,7 +3851,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3778,13 +3867,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3794,7 +3883,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3806,7 +3895,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3818,7 +3907,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self) -> None:
@@ -3833,7 +3922,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3844,7 +3933,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3855,7 +3944,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3866,7 +3955,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self) -> None:
@@ -3982,7 +4071,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Try to get information of a user without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -3995,7 +4084,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -4008,7 +4097,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
@@ -4021,7 +4110,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_success(self) -> None:
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index b5e7eecf8..30f12f1bf 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from http import HTTPStatus
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
if username == "allowed":
return True
raise SynapseError(
- HTTPStatus.BAD_REQUEST,
+ 400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@@ -67,10 +64,6 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 7ae926dc9..c1a7fb2f8 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 000000000..a9da00665
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from pydantic import ValidationError
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class EmailRequestTokenBodyTestCase(unittest.TestCase):
+ base_request = {
+ "client_secret": "hunter2",
+ "email": "alice@wonderland.com",
+ "send_attempt": 1,
+ }
+
+ def test_token_required_if_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ }
+ )
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": None,
+ }
+ )
+
+ def test_token_typechecked_when_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": 1337,
+ }
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c11335..9c8c1889d 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a2..c50f034b3 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/server.py b/tests/server.py
index 9689e6a0c..c447d5e4c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -61,6 +61,10 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.events.presence_router import load_legacy_presence_router
+from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
@@ -913,4 +917,14 @@ def setup_test_homeserver(
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
+ # Load any configured modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_spam_checkers(hs)
+ load_legacy_third_party_event_rules(hs)
+ load_legacy_presence_router(hs)
+ load_legacy_password_auth_providers(hs)
+
return hs
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index e07ae78fc..bf403045e 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -52,7 +55,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
@@ -251,7 +254,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
c["admin_contact"] = "mailto:user@test.com"
return c
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index ba40124c8..62fd4aeb2 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -135,7 +135,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_assert_counts(1, 1, 0)
# Delete old event push actions, this should not affect the (summarised) count.
+ #
+ # All event push actions are kept for 24 hours, so need to move forward
+ # in time.
+ self.pump(60 * 60 * 24)
self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ # Double check that the event push actions have been cleared (i.e. that
+ # any results *must* come from the summary).
+ result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="event_push_actions",
+ keyvalues={"1": 1},
+ retcols=("event_id",),
+ desc="",
+ )
+ )
+ self.assertEqual(result, [])
_assert_counts(1, 1, 0)
_mark_read(last_event_id)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 240b02cb9..ceec69028 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -23,6 +23,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
+from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -157,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
+ def test_room_is_locally_forgotten(self):
+ """Test that when the last local user has forgotten a room it is known as forgotten."""
+ # join two local and one remote user
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_charlie.to_string(), "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # local users leave the room and the room is not forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave")
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # first user forgets the room, room is not forgotten
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # second (last local) user forgets the room and the room is forgotten
+ self.get_success(self.store.forget(self.u_bob, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ def test_join_locally_forgotten_room(self):
+ """Tests if a user joins a forgotten room the room is not forgotten anymore."""
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after leaving and forget the room, it is forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after rejoin the room is not forgotten anymore
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/unittest.py b/tests/unittest.py
index bec4a3d02..975b0a23a 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -677,14 +677,29 @@ class HomeserverTestCase(TestCase):
username: str,
password: str,
device_id: Optional[str] = None,
+ additional_request_fields: Optional[Dict[str, str]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be registered.
+
+ Args:
+ username: The localpart to assign to the new user.
+ password: The password to assign to the new user.
+ device_id: An optional device ID to assign to the new device created during
+ login.
+ additional_request_fields: A dictionary containing any additional /login
+ request fields and their values.
+ custom_headers: Custom HTTP headers and values to add to the /login request.
+
+ Returns:
+ The newly registered user's Matrix ID.
"""
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
body["device_id"] = device_id
+ if additional_request_fields:
+ body.update(additional_request_fields)
channel = self.make_request(
"POST",