diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh index 54ec3c8b0..b2859f752 100755 --- a/.ci/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive set -ex apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev +apt-get install -y \ + python3 python3-dev python3-pip python3-venv \ + libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev export LANG="C.UTF-8" diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..acb118c86 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +# TODO: incorporate this into pyproject.toml if flake8 supports it in the future. +# See https://github.com/PyCQA/flake8/issues/234 +[flake8] +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) +# E731: do not assign a lambda expression, use a def +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eb294f161..eee3633d5 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -7,7 +7,7 @@ on: # of things breaking (but only build one set of debs) pull_request: push: - branches: ["develop"] + branches: ["develop", "release-*"] # we do the full build on tags. tags: ["v*"] @@ -91,17 +91,7 @@ jobs: build-sdist: name: "Build pypi distribution files" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - run: pip install wheel - - run: | - python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: python-dist - path: dist/* + uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1" # if it's a tag, create a release and attach the artifacts to it attach-assets: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75ac1304b..e9e427732 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,12 +10,19 @@ concurrency: cancel-in-progress: true jobs: + check-sampleconfig: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - run: pip install -e . + - run: scripts-dev/generate_sample_config --check + lint: runs-on: ubuntu-latest strategy: matrix: toxenv: - - "check-sampleconfig" - "check_codestyle" - "check_isort" - "mypy" @@ -43,29 +50,15 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - uses: actions/setup-python@v2 - - run: pip install tox + - run: "pip install 'towncrier>=18.6.0rc1'" - run: scripts-dev/check-newsfragment env: PULL_REQUEST_NUMBER: ${{ github.event.number }} - lint-sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: "3.x" - - run: pip install wheel - - run: python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: Python Distributions - path: dist/* - # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile, lint-sdist] + needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig] runs-on: ubuntu-latest steps: - run: "true" @@ -397,7 +390,6 @@ jobs: - lint - lint-crlf - lint-newsfile - - lint-sdist - trial - trial-olddeps - sytest diff --git a/CHANGES.md b/CHANGES.md index 8d91a7921..a81d0a4b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,106 @@ +Synapse 1.54.0rc1 (2022-03-02) +============================== + +Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. +Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. + + +Features +-------- + +- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617)) +- Improve the preview that is produced when generating URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985)) +- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000)) +- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067)) +- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009)) +- Advertise Matrix 1.1 and 1.2 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020), ([\#12022](https://github.com/matrix-org/synapse/issues/12022)) +- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021)) +- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058)) +- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992)) +- Fix long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999)) +- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024)) +- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037)) +- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056)) +- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077)) +- Fix occasional `Unhandled error in Deferred` error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089)) +- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098)) +- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100)) +- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105)) +- Make a `POST` to `/rooms//receipt/m.read/` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. This reduces server load and load on the receiving device. ([\#11835](https://github.com/matrix-org/synapse/issues/11835)) + + +Updates to the Docker image +--------------------------- + +- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997)) +- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112)) + + +Improved Documentation +---------------------- + +- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599)) +- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003)) +- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004)) + + +Deprecations and Removals +------------------------- + +- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865)) +- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008)) +- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018)) +- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073)) + + +Internal Changes +---------------- + +- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808)) +- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900)) +- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972)) +- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974)) +- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984)) +- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991)) +- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994)) +- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996)) +- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039)) +- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011)) +- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012)) +- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013)) +- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015)) +- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016)) +- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019)) +- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025)) +- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030)) +- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070)) +- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069)) +- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041)) +- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051)) +- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059)) +- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060)) +- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063)) +- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094)) +- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068)) +- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088)) +- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092)) +- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099)) +- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106)) +- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109)) +- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111)) +- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119)) + + Synapse 1.53.0 (2022-02-22) =========================== -No significant changes. +No significant changes since 1.53.0rc1. Synapse 1.53.0rc1 (2022-02-15) @@ -11,7 +110,7 @@ Features -------- - Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). ([\#11215](https://github.com/matrix-org/synapse/issues/11215), [\#11966](https://github.com/matrix-org/synapse/issues/11966)) -- Remove account data (including client config, push rules and ignored users) upon user deactivation. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) +- Add a background database update to purge account data for deactivated users. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) - Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results. ([\#11837](https://github.com/matrix-org/synapse/issues/11837)) - Enable cache time-based expiry by default. The `expiry_time` config flag has been superseded by `expire_caches` and `cache_entry_ttl`. ([\#11849](https://github.com/matrix-org/synapse/issues/11849)) - Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account. ([\#11854](https://github.com/matrix-org/synapse/issues/11854)) @@ -92,7 +191,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. @@ -273,7 +372,7 @@ Bugfixes Synapse 1.50.0 (2022-01-18) =========================== -**This release contains a critical bug that may prevent clients from being able to connect. +**This release contains a critical bug that may prevent clients from being able to connect. As such, it is not recommended to upgrade to 1.50.0. Instead, please upgrade straight to to 1.50.1. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).** diff --git a/MANIFEST.in b/MANIFEST.in index c24786c3b..76d14eb64 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -45,6 +45,7 @@ include book.toml include pyproject.toml recursive-include changelog.d * +include .flake8 prune .circleci prune .github prune .ci diff --git a/debian/changelog b/debian/changelog index 574930c08..df3db85b8 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium + + * New synapse release 1.54.0~rc1. + + -- Synapse Packaging team Wed, 02 Mar 2022 10:43:22 +0000 + matrix-synapse-py3 (1.53.0) stable; urgency=medium * New synapse release 1.53.0. diff --git a/docker/Dockerfile b/docker/Dockerfile index 306f75ae5..a8bb9b0e7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,10 +11,10 @@ # There is an optional PYTHON_VERSION build argument which sets the # version of python to build against: for example: # -# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.10 . # -ARG PYTHON_VERSION=3.8 +ARG PYTHON_VERSION=3.9 ### ### Stage 0: builder @@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py COPY ./docker/conf /conf -VOLUME ["/data"] - EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 1bbe23708..4076fcab6 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -126,7 +126,8 @@ Body parameters: [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) section `sso` and `oidc_providers`. - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` - in homeserver configuration. + in the homeserver configuration. Note that no error is raised if the provided + value is not in the homeserver configuration. - `external_id` - string, user ID in the external identity provider. - `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). diff --git a/docs/manhole.md b/docs/manhole.md index 715ed840f..a82fad0f0 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database: ```pycon >>> from twisted.internet import defer ->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) +>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 88b59bb09..ec810fd29 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -162,10 +162,38 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_displayname_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_displayname_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback returns `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + ## `is_3pid_allowed` _First introduced in Synapse v1.53.0_ @@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` - + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2eb9032f4..2b672b78f 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_ async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str] ``` -Called when receiving an event from a client or via federation. The module can return -either a `bool` to indicate whether the event must be rejected because of spam, or a `str` -to indicate the event must be rejected because of spam and to give a rejection reason to -forward to clients. +Called when receiving an event from a client or via federation. The callback must return +either: +- an error message string, to indicate the event must be rejected because of spam and + give a rejection reason to forward to clients; +- the boolean `True`, to indicate that the event is spammy, but not provide further details; or +- the booelan `False`, to indicate that the event is not considered spammy. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first @@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool ``` Called when a user is trying to join a room. The module must return a `bool` to indicate -whether the user can join the room. The user is represented by their Matrix user ID (e.g. +whether the user can join the room. Return `False` to prevent the user from joining the +room; otherwise return `True` to permit the joining. + +The user is represented by their Matrix user ID (e.g. `@alice:example.com`) and the room is represented by its Matrix ID (e.g. `!room:example.com`). The module is also given a boolean to indicate whether the user currently has a pending invite in the room. @@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (e.g. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent +the invitation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -80,7 +86,8 @@ async def user_may_send_3pid_invite( Called when processing an invitation using a third-party identifier (also called a 3PID, e.g. an email address or a phone number). The module must return a `bool` indicating -whether the inviter can invite the invitee to the given room. +whether the inviter can invite the invitee to the given room. Return `False` to prevent +the invitation; otherwise return `True` to permit it. The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the invitee is represented by its medium (e.g. "email") and its address @@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +Return `False` to prevent room creation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA Called when trying to associate an alias with an existing room. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed -to set the given alias. +to set the given alias. Return `False` to prevent the alias creation; otherwise return +`True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool Called when trying to publish a room to the homeserver's public rooms directory. The module must return a `bool` indicating whether the given user (represented by their -Matrix user ID) is allowed to publish the given room. +Matrix user ID) is allowed to publish the given room. Return `False` to prevent the +room from being published; otherwise return `True` to permit its publication. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool ``` Called when computing search results in the user directory. The module must return a -`bool` indicating whether the given user profile can appear in search results. The profile -is represented as a dictionary with the following keys: +`bool` indicating whether the given user should be excluded from user directory +searches. Return `True` to indicate that the user is spammy and exclude them from +search results; otherwise return `False`. + +The profile is represented as a dictionary with the following keys: * `user_id`: The Matrix ID for this user. * `display_name`: The user's display name. @@ -225,8 +238,9 @@ async def check_media_file_for_spam( ) -> bool ``` -Called when storing a local or remote file. The module must return a boolean indicating -whether the given file can be stored in the homeserver's media store. +Called when storing a local or remote file. The module must return a `bool` indicating +whether the given file should be excluded from the homeserver's media store. Return +`True` to prevent this file from being stored; otherwise return `False`. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index a3a17096a..09ac83810 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,62 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `on_profile_update` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_profile_update( + user_id: str, + new_profile: "synapse.module_api.ProfileInfo", + by_admin: bool, + deactivation: bool, +) -> None: +``` + +Called after updating a local user's profile. The update can be triggered either by the +user themselves or a server admin. The update can also be triggered by a user being +deactivated (in which case their display name is set to an empty string (`""`) and the +avatar URL is set to `None`). The module is passed the Matrix ID of the user whose profile +has been updated, their new profile, as well as a `by_admin` boolean that is `True` if the +update was triggered by a server admin (and `False` otherwise), and a `deactivated` +boolean that is `True` if the update is a result of the user being deactivated. + +Note that the `by_admin` boolean is also `True` if the profile change happens as a result +of the user logging in through Single Sign-On, or if a server admin updates their own +profile. + +Per-room profile changes do not trigger this callback to be called. Synapse administrators +wishing this callback to be called on every profile change are encouraged to disable +per-room profiles globally using the `allow_per_room_profiles` configuration setting in +Synapse's configuration file. +This callback is not called when registering a user, even when setting it through the +[`get_displayname_for_registration`](https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html#get_displayname_for_registration) +module callback. + +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_user_deactivation_status_changed` + +_First introduced in Synapse v1.54.0_ + +```python +async def on_user_deactivation_status_changed( + user_id: str, deactivated: bool, by_admin: bool +) -> None: +``` + +Called after deactivating a local user, or reactivating them through the admin API. The +deactivation can be triggered either by the user themselves or a server admin. The module +is passed the Matrix ID of the user whose status is changed, as well as a `deactivated` +boolean that is `True` if the user is being deactivated and `False` if they're being +reactivated, and a `by_admin` boolean that is `True` if the deactivation was triggered by +a server admin (and `False` otherwise). This latter `by_admin` boolean is always `True` +if the user is being reactivated, as this operation can only be performed through the +admin API. + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d2bb3d420..6f3623c88 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -163,7 +163,7 @@ presence: # For example, for room version 1, default_room_version should be set # to "1". # -#default_room_version: "6" +#default_room_version: "9" # The GC threshold parameters to pass to `gc.set_threshold`, if defined # diff --git a/docs/structured_logging.md b/docs/structured_logging.md index 14db85f58..805c86765 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999. ## Upgrading from legacy structured logging configuration -Versions of Synapse prior to v1.23.0 included a custom structured logging -configuration which is deprecated. It used a `structured: true` flag and -configured `drains` instead of ``handlers`` and `formatters`. +Versions of Synapse prior to v1.54.0 automatically converted the legacy +structured logging configuration, which was deprecated in v1.23.0, to the standard +library logging configuration. -Synapse currently automatically converts the old configuration to the new -configuration, but this will be removed in a future version of Synapse. The -following reference can be used to update your configuration. Based on the drain -`type`, we can pick a new handler: +The following reference can be used to update your configuration. Based on the +drain `type`, we can pick a new handler: 1. For a type of `console`, `console_json`, or `console_json_terse`: a handler with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` diff --git a/docs/upgrade.md b/docs/upgrade.md index 477d7d0e8..f9be3ac6b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.54.0 + +## Legacy structured logging configuration removal + +This release removes support for the `structured: true` logging configuration +which was deprecated in Synapse v1.23.0. If your logging configuration contains +`structured: true` then it should be modified based on the +[structured logging documentation](structured_logging.md). + # Upgrading to v1.53.0 ## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location` @@ -157,7 +166,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx) within the Twisted library. We do not believe Synapse is affected by this vulnerability, though we advise server administrators who installed Synapse via pip to upgrade Twisted -with `pip install --upgrade Twisted` as a matter of good practice. The Docker image +with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image `matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the updated library. diff --git a/docs/workers.md b/docs/workers.md index dadde4d72..b0f8599ef 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -178,8 +178,11 @@ recommend the use of `systemd` where available: for information on setting up ### `synapse.app.generic_worker` -This worker can handle API requests matching the following regular -expressions: +This worker can handle API requests matching the following regular expressions. +These endpoints can be routed to any worker. If a worker is set up to handle a +stream then, for maximum efficiency, additional endpoints should be routed to that +worker: refer to the [stream writers](#stream-writers) section below for further +information. # Sync requests ^/_matrix/client/(v2_alpha|r0|v3)/sync$ @@ -209,7 +212,6 @@ expressions: ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query - ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request @@ -222,22 +224,25 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/ + ^/_matrix/client/(r0|v3|unstable)/joined_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ + ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + # Encryption requests + ^/_matrix/client/(r0|v3|unstable)/keys/query$ + ^/_matrix/client/(r0|v3|unstable)/keys/changes$ + ^/_matrix/client/(r0|v3|unstable)/keys/claim$ + ^/_matrix/client/(r0|v3|unstable)/room_keys/ + # Registration/login requests ^/_matrix/client/(api/v1|r0|v3|unstable)/login$ ^/_matrix/client/(r0|v3|unstable)/register$ @@ -251,6 +256,20 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ + # Device requests + ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ + + # Account data requests + ^/_matrix/client/(r0|v3|unstable)/.*/tags + ^/_matrix/client/(r0|v3|unstable)/.*/account_data + + # Receipts requests + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers + + # Presence requests + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + Additionally, the following REST endpoints can be handled for GET requests: @@ -330,12 +349,10 @@ Additionally, there is *experimental* support for moving writing of specific streams (such as events) off of the main process to a particular worker. (This is only supported with Redis-based replication.) -Currently supported streams are `events` and `typing`. - To enable this, the worker must have a HTTP replication listener configured, -have a `worker_name` and be listed in the `instance_map` config. For example to -move event persistence off to a dedicated worker, the shared configuration would -include: +have a `worker_name` and be listed in the `instance_map` config. The same worker +can handle multiple streams. For example, to move event persistence off to a +dedicated worker, the shared configuration would include: ```yaml instance_map: @@ -347,6 +364,12 @@ stream_writers: events: event_persister1 ``` +Some of the streams have associated endpoints which, for maximum efficiency, should +be routed to the workers handling that stream. See below for the currently supported +streams and the endpoints associated with them: + +##### The `events` stream + The `events` stream also experimentally supports having multiple writers, where work is sharded between them by room ID. Note that you *must* restart all worker instances when adding or removing event persisters. An example `stream_writers` @@ -359,6 +382,43 @@ stream_writers: - event_persister2 ``` +##### The `typing` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `typing` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing + +##### The `to_device` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `to_device` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ + +##### The `account_data` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `account_data` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags + ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data + +##### The `receipts` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `receipts` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers + +##### The `presence` stream + +The following endpoints should be routed directly to the workers configured as +stream writers for the `presence` stream: + + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ + #### Background tasks There is also *experimental* support for moving background tasks to a separate diff --git a/mypy.ini b/mypy.ini index 63848d664..38ff78760 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,14 +31,11 @@ exclude = (?x) |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ |tests/api/test_auth.py @@ -78,16 +75,12 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/client/test_account.py - |tests/rest/client/test_events.py |tests/rest/client/test_filter.py - |tests/rest/client/test_groups.py - |tests/rest/client/test_register.py |tests/rest/client/test_report_event.py |tests/rest/client/test_rooms.py |tests/rest/client/test_third_party_rules.py |tests/rest/client/test_transactions.py |tests/rest/client/test_typing.py - |tests/rest/client/utils.py |tests/rest/key/v2/test_remote_key_resource.py |tests/rest/media/v1/test_base.py |tests/rest/media/v1/test_media_storage.py @@ -256,7 +249,7 @@ disallow_untyped_defs = True [mypy-tests.rest.admin.*] disallow_untyped_defs = True -[mypy-tests.rest.client.test_directory] +[mypy-tests.rest.client.*] disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] diff --git a/pyproject.toml b/pyproject.toml index 963f149c6..c9cd0cf6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,15 @@ exclude = ''' )/ ) ''' + +[tool.isort] +line_length = 88 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"] +default_section = "THIRDPARTY" +known_first_party = ["synapse"] +known_tests = ["tests"] +known_twisted = ["twisted", "OpenSSL"] +multi_line_output = 3 +include_trailing_comma = true +combine_as_imports = true + diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index c764011d6..493558ad6 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -35,7 +35,7 @@ CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing y https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit -tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) +python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) echo echo "--------------------------" diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index e08ffedaf..0aecb3daf 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -5,7 +5,7 @@ # It makes a Synapse image which represents the current checkout, # builds a synapse-complement image on top, then runs tests with it. # -# By default the script will fetch the latest Complement master branch and +# By default the script will fetch the latest Complement main branch and # run tests with that. This can be overridden to use a custom Complement # checkout by setting the COMPLEMENT_DIR environment variable to the # filepath of a local Complement checkout or by setting the COMPLEMENT_REF @@ -32,7 +32,7 @@ cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then - COMPLEMENT_REF=${COMPLEMENT_REF:-master} + COMPLEMENT_REF=${COMPLEMENT_REF:-main} echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..." wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz tar -xzf ${COMPLEMENT_REF}.tar.gz diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 5c6453d77..f43676afa 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -44,7 +44,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs): - store = hs.get_datastore() + store = hs.get_datastores().main async def run_background_updates(): await store.db_pool.updates.run_background_updates(sleep=False) diff --git a/setup.cfg b/setup.cfg index e5ceb7ed1..6213f3265 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[trial] -test_suite = tests - [check-manifest] ignore = .git-blame-ignore-revs @@ -10,23 +7,3 @@ ignore = pylint.cfg tox.ini -[flake8] -# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes -# for error codes. The ones we ignore are: -# W503: line break before binary operator -# W504: line break after binary operator -# E203: whitespace before ':' (which is contrary to pep8?) -# E731: do not assign a lambda expression, use a def -# E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 - -[isort] -line_length = 88 -sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER -default_section=THIRDPARTY -known_first_party = synapse -known_tests=tests -known_twisted=twisted,OpenSSL -multi_line_output=3 -include_trailing_comma=true -combine_as_imports=true diff --git a/setup.py b/setup.py index d0511c767..26f465034 100755 --- a/setup.py +++ b/setup.py @@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ ] CONDITIONAL_REQUIREMENTS["mypy"] = [ - "mypy==0.910", - "mypy-zope==0.3.2", + "mypy==0.931", + "mypy-zope==0.3.5", "types-bleach>=4.1.0", "types-jsonschema>=3.2.0", "types-opentracing>=2.4.2", @@ -165,6 +165,7 @@ setup( "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 0eaef0049..344d55cce 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]): def __copy__(self: _SD) -> _SD: ... @classmethod @overload - def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + def fromkeys( + cls, seq: Iterable[_T_h], value: None = ... + ) -> SortedDict[_T_h, None]: ... @classmethod @overload def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... - def keys(self) -> SortedKeysView[_KT]: ... - def items(self) -> SortedItemsView[_KT, _VT]: ... - def values(self) -> SortedValuesView[_VT]: ... + # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so + # `Sorted{Keys,Items,Values}View` are no longer compatible with them. + # See https://github.com/python/typeshed/issues/6837 + def keys(self) -> SortedKeysView[_KT]: ... # type: ignore[override] + def items(self) -> SortedItemsView[_KT, _VT]: ... # type: ignore[override] + def values(self) -> SortedValuesView[_VT]: ... # type: ignore[override] @overload def pop(self, key: _KT) -> _VT: ... @overload diff --git a/synapse/__init__.py b/synapse/__init__.py index 903f2e815..b21e1ed0f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.53.0" +__version__ = "1.54.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b9afaeecd..64fc05513 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -60,7 +60,7 @@ class Auth: def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 08fe160c9..22348d2d8 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AuthBlocking: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._hs_disabled = hs.config.server.hs_disabled diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index d087c816d..cb532d723 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -22,6 +22,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, TypeVar, @@ -150,7 +151,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID: class Filtering: def __init__(self, hs: "HomeServer"): self._hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) @@ -294,7 +295,7 @@ class FilterCollection: class Filter: def __init__(self, hs: "HomeServer", filter_json: JsonDict): self._hs = hs - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.filter_json = filter_json self.limit = filter_json.get("limit", 10) @@ -361,10 +362,10 @@ class Filter: return self._check_fields(field_matchers) else: content = event.get("content") - # Content is assumed to be a dict below, so ensure it is. This should + # Content is assumed to be a mapping below, so ensure it is. This should # always be true for events, but account_data has been allowed to # have non-dict content. - if not isinstance(content, dict): + if not isinstance(content, Mapping): content = {} sender = event.get("sender", None) diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index ee51480a9..334c3d2c1 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -15,13 +15,13 @@ import logging import sys from typing import Container -from synapse import python_dependencies # noqa: E402 +from synapse.util import check_dependencies logger = logging.getLogger(__name__) try: - python_dependencies.check_requirements() -except python_dependencies.DependencyException as e: + check_dependencies.check_requirements() +except check_dependencies.DependencyException as e: sys.stderr.writelines( e.message # noqa: B306, DependencyException.message is a property ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 452c0c09d..3e59805ba 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -448,7 +448,7 @@ async def start(hs: "HomeServer") -> None: # It is now safe to start your Synapse. hs.start_listening() - hs.get_datastore().db_pool.start_profiling() + hs.get_datastores().main.db_pool.start_profiling() hs.get_pusherpool().start() # Log when we start the shut down process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index aadc882bf..1536a4272 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -142,7 +142,7 @@ class KeyUploadServlet(RestServlet): """ super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bfb30003c..a6789a840 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -59,7 +59,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource @@ -70,6 +69,7 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.util.check_dependencies import check_requirements from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module @@ -372,7 +372,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer: await _base.start(hs) - hs.get_datastore().db_pool.updates.start_doing_background_updates() + hs.get_datastores().main.db_pool.updates.start_doing_background_updates() register_start(start) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 899dba5c3..40dbdace8 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -82,7 +82,7 @@ async def phone_stats_home( # General statistics # - store = hs.get_datastore() + store = hs.get_datastores().main stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server.server_context @@ -170,18 +170,22 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + clock.looping_call( + hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + ) # monthly active user limiting functionality - clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) - hs.get_datastore().reap_monthly_active_users() + clock.looping_call( + hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + ) + hs.get_datastores().main.reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} reserved_users: Sized = () - store = hs.get_datastore() + store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index a340a8c9c..4d3f8e492 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -31,6 +31,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Type for the `device_one_time_key_counts` field in an appservice transaction +# user ID -> {device ID -> {algorithm -> count}} +TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]] + +# Type for the `device_unused_fallback_keys` field in an appservice transaction +# user ID -> {device ID -> [algorithm]} +TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]] + class ApplicationServiceState(Enum): DOWN = "down" @@ -72,6 +80,7 @@ class ApplicationService: rate_limited: bool = True, ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, + msc3202_transaction_extensions: bool = False, ): self.token = token self.url = ( @@ -84,6 +93,7 @@ class ApplicationService: self.id = id self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral + self.msc3202_transaction_extensions = msc3202_transaction_extensions if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -339,12 +349,16 @@ class AppServiceTransaction: events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral self.to_device_messages = to_device_messages + self.one_time_key_counts = one_time_key_counts + self.unused_fallback_keys = unused_fallback_keys async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -359,6 +373,8 @@ class AppServiceTransaction: events=self.events, ephemeral=self.ephemeral, to_device_messages=self.to_device_messages, + one_time_key_counts=self.one_time_key_counts, + unused_fallback_keys=self.unused_fallback_keys, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 73be7ff3d..a0ea958af 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -19,6 +19,11 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient @@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: - from synapse.appservice import ApplicationService from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient): events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, txn_id: Optional[int] = None, ) -> bool: """ @@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it - body: Dict[str, List[JsonDict]] = {"events": serialized_events} + body: JsonDict = {"events": serialized_events} if service.supports_ephemeral: body.update( { @@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient): } ) + if service.msc3202_transaction_extensions: + if one_time_key_counts: + body[ + "org.matrix.msc3202.device_one_time_key_counts" + ] = one_time_key_counts + if unused_fallback_keys: + body[ + "org.matrix.msc3202.device_unused_fallback_keys" + ] = unused_fallback_keys + try: await self.put_json( uri=uri, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index c42fa32ff..72417151b 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -54,12 +54,19 @@ from typing import ( Callable, Collection, Dict, + Iterable, List, Optional, Set, + Tuple, ) -from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase from synapse.logging.context import run_in_background @@ -92,11 +99,11 @@ class ApplicationServiceScheduler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.as_api = hs.get_application_service_api() self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) + self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) async def start(self) -> None: logger.info("Starting appservice scheduler") @@ -153,7 +160,9 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): + def __init__( + self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer" + ): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} @@ -165,6 +174,10 @@ class _ServiceQueuer: self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock + self._msc3202_transaction_extensions_enabled: bool = ( + hs.config.experimental.msc3202_transaction_extensions + ) + self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one @@ -202,15 +215,84 @@ class _ServiceQueuer: if not events and not ephemeral and not to_device_messages_to_send: return + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + + if ( + self._msc3202_transaction_extensions_enabled + and service.msc3202_transaction_extensions + ): + # Compute the one-time key counts and fallback key usage states + # for the users which are mentioned in this transaction, + # as well as the appservice's sender. + ( + one_time_key_counts, + unused_fallback_keys, + ) = await self._compute_msc3202_otk_counts_and_fallback_keys( + service, events, ephemeral, to_device_messages_to_send + ) + try: await self.txn_ctrl.send( - service, events, ephemeral, to_device_messages_to_send + service, + events, + ephemeral, + to_device_messages_to_send, + one_time_key_counts, + unused_fallback_keys, ) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) + async def _compute_msc3202_otk_counts_and_fallback_keys( + self, + service: ApplicationService, + events: Iterable[EventBase], + ephemerals: Iterable[JsonDict], + to_device_messages: Iterable[JsonDict], + ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]: + """ + Given a list of the events, ephemeral messages and to-device messages, + - first computes a list of application services users that may have + interesting updates to the one-time key counts or fallback key usage. + - then computes one-time key counts and fallback key usages for those users. + Given a list of application service users that are interesting, + compute one-time key counts and fallback key usages for the users. + """ + + # Set of 'interesting' users who may have updates + users: Set[str] = set() + + # The sender is always included + users.add(service.sender) + + # All AS users that would receive the PDUs or EDUs sent to these rooms + # are classed as 'interesting'. + rooms_of_interesting_users: Set[str] = set() + # PDUs + rooms_of_interesting_users.update(event.room_id for event in events) + # EDUs + rooms_of_interesting_users.update( + ephemeral["room_id"] for ephemeral in ephemerals + ) + + # Look up the AS users in those rooms + for room_id in rooms_of_interesting_users: + users.update( + await self._store.get_app_service_users_in_room(room_id, service) + ) + + # Add recipients of to-device messages. + # device_message["user_id"] is the ID of the recipient. + users.update(device_message["user_id"] for device_message in to_device_messages) + + # Compute and return the counts / fallback key usage states + otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users) + unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users) + return otk_counts, unused_fbks + class _TransactionController: """Transaction manager. @@ -238,6 +320,8 @@ class _TransactionController: events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None, + one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, + unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -248,6 +332,10 @@ class _TransactionController: events: The persistent events to include in the transaction. ephemeral: The ephemeral events to include in the transaction. to_device_messages: The to-device messages to include in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. """ try: txn = await self.store.create_appservice_txn( @@ -255,6 +343,8 @@ class _TransactionController: events=events, ephemeral=ephemeral or [], to_device_messages=to_device_messages or [], + one_time_key_counts=one_time_key_counts or {}, + unused_fallback_keys=unused_fallback_keys or {}, ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7fad2e042..439bfe152 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -166,6 +166,16 @@ def _load_appservice( supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + # Opt-in flag for the MSC3202-specific transactional behaviour. + # When enabled, appservice transactions contain the following information: + # - device One-Time Key counts + # - device unused fallback key usage states + msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) + if not isinstance(msc3202_transaction_extensions, bool): + raise ValueError( + "The `org.matrix.msc3202` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -174,8 +184,9 @@ def _load_appservice( hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], - supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, + supports_ephemeral=supports_ephemeral, + msc3202_transaction_extensions=msc3202_transaction_extensions, ) diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 387ac6d11..9a68da9c3 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 09d692d9a..41338b39d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -41,20 +41,12 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) - # MSC3283 (set displayname, avatar_url and change 3pid capabilities) - self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) - # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # The portion of MSC3202 which is related to device masquerading. - self.msc3202_device_masquerading_enabled: bool = experimental.get( - "msc3202_device_masquerading", False - ) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. @@ -62,5 +54,23 @@ class ExperimentalConfig(Config): "msc2409_to_device_messages_enabled", False ) + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) + + # Portion of MSC3202 related to transaction extensions: + # sending one-time key counts and fallback key usage to application services. + self.msc3202_transaction_extensions: bool = experimental.get( + "msc3202_transaction_extensions", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) + + # experimental support for faster joins over federation (msc2775, msc3706) + # requires a target server with msc3706_enabled enabled. + self.faster_joins_enabled: bool = experimental.get("faster_joins", False) + + # MSC3720 (Account status endpoint) + self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b7145a44a..cbbe22196 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -33,7 +33,6 @@ from twisted.logger import ( globalLogBeginner, ) -from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter @@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option removed in Synapse 1.3.0. You should instead set up a separate log configuration file. """ +STRUCTURED_ERROR = """\ +Support for the structured configuration option was removed in Synapse 1.54.0. +You should instead use the standard logging configuration. See +https://matrix-org.github.io/synapse/v1.54/structured_logging.html +""" + class LoggingConfig(Config): section = "logging" @@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None: if not log_config: logging.warning("Loaded a blank logging config?") - # If the old structured logging configuration is being used, convert it to - # the new style configuration. + # If the old structured logging configuration is being used, raise an error. if "structured" in log_config and log_config.get("structured"): - log_config = setup_structured_logging(log_config) + raise ConfigError(STRUCTURED_ERROR) logging.config.dictConfig(log_config) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 1cc26e757..f62292ecf 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,7 +15,7 @@ import attr -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index e783b1131..f7e4f9ef2 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -20,11 +20,11 @@ import attr from synapse.config._util import validate_config from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri +from ..util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider" diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 33104af73..bdb1aac3a 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.config._base import Config -from synapse.python_dependencies import check_requirements +from synapse.util.check_dependencies import check_requirements class RedisConfig(Config): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c964faae..d1e45ecd2 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,8 +20,8 @@ from urllib.request import getproxies_environment # type: ignore import attr from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module from ._base import Config, ConfigError diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ec9d9f65e..43c456d5c 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -17,8 +17,8 @@ import logging from typing import Any, List, Set from synapse.config.sso import SsoAttributeRequirement -from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict +from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc962454..49cd0a4f1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -146,7 +146,7 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "fec0::/10", ] -DEFAULT_ROOM_VERSION = "6" +DEFAULT_ROOM_VERSION = "9" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. " diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 21b9a8835..7aff618ea 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -14,7 +14,7 @@ from typing import Set -from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.check_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 72d4a69aa..93d56c077 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _fetch_keys( self, keys_to_fetch: List[_FetchKeyRequest] @@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config async def process_v2_response( diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eca00bc97..621a3efcc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -374,9 +374,9 @@ def _is_membership_change_allowed( return # Require the user to be in the room for membership changes other than join/knock. - if Membership.JOIN != membership and ( - RoomVersion.msc2403_knocking and Membership.KNOCK != membership - ): + # Note that the room version check for knocking is done implicitly by `caller_knocked` + # and the ability to set a membership of `knock` in the first place. + if Membership.JOIN != membership and Membership.KNOCK != membership: # If the user has been invited or has knocked, they are allowed to change their # membership event to leave if ( diff --git a/synapse/events/builder.py b/synapse/events/builder.py index eb39e0ae3..1ea1bb7d3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -189,7 +189,7 @@ class EventBuilderFactory: self.hostname = hs.hostname self.signing_key = hs.signing_key - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5833fee25..46042b2bf 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -101,6 +101,9 @@ class EventContext: As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. + + partial_state: if True, we may be storing this event with a temporary, + incomplete state. """ rejected: Union[bool, str] = False @@ -113,12 +116,15 @@ class EventContext: _current_state_ids: Optional[StateMap[str]] = None _prev_state_ids: Optional[StateMap[str]] = None + partial_state: bool = False + @staticmethod def with_state( state_group: Optional[int], state_group_before_event: Optional[int], current_state_ids: Optional[StateMap[str]], prev_state_ids: Optional[StateMap[str]], + partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": @@ -129,6 +135,7 @@ class EventContext: state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, + partial_state=partial_state, ) @staticmethod @@ -170,6 +177,7 @@ class EventContext: "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, + "partial_state": self.partial_state, } @staticmethod @@ -196,6 +204,7 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], + partial_state=input.get("partial_state", False), ) app_service_id = input["app_service_id"] diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1bb8ca714..dd3104faf 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tupl from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import maybe_awaitable @@ -37,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] +ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -143,7 +146,7 @@ class ThirdPartyEventRules: def __init__(self, hs: "HomeServer"): self.third_party_rules = None - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -154,6 +157,10 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] + self._on_user_deactivation_status_changed_callbacks: List[ + ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -166,6 +173,8 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, + on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -187,6 +196,12 @@ class ThirdPartyEventRules: if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if on_profile_update is not None: + self._on_profile_update_callbacks.append(on_profile_update) + + if on_deactivation is not None: + self._on_user_deactivation_status_changed_callbacks.append(on_deactivation) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -334,9 +349,6 @@ class ThirdPartyEventRules: Args: event_id: The ID of the event. - - Raises: - ModuleFailureError if a callback raised any exception. """ # Bail out early without hitting the store if we don't have any callbacks if len(self._on_new_event_callbacks) == 0: @@ -370,3 +382,41 @@ class ThirdPartyEventRules: state_events[key] = room_state_events[event_id] return state_events + + async def on_profile_update( + self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool + ) -> None: + """Called after the global profile of a user has been updated. Does not include + per-room profile changes. + + Args: + user_id: The user whose profile was changed. + new_profile: The updated profile for the user. + by_admin: Whether the profile update was performed by a server admin. + deactivation: Whether this change was made while deactivating the user. + """ + for callback in self._on_profile_update_callbacks: + try: + await callback(user_id, new_profile, by_admin, deactivation) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_user_deactivation_status_changed( + self, user_id: str, deactivated: bool, by_admin: bool + ) -> None: + """Called after a user has been deactivated or reactivated. + + Args: + user_id: The deactivated user. + deactivated: Whether the user is now deactivated. + by_admin: Whether the deactivation was performed by a server admin. + """ + for callback in self._on_user_deactivation_status_changed_callbacks: + try: + await callback(user_id, deactivated, by_admin) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 243696b35..9386fa29d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -425,6 +425,33 @@ class EventClientSerializer: return serialized_event + def _apply_edit( + self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase + ) -> None: + """Replace the content, preserving existing relations of the serialized event. + + Args: + orig_event: The original event. + serialized_event: The original event, serialized. This is modified. + edit: The event which edits the above. + """ + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = orig_event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + def _inject_bundled_aggregations( self, event: EventBase, @@ -450,26 +477,11 @@ class EventClientSerializer: serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references if aggregations.replace: - # If there is an edit replace the content, preserving existing - # relations. + # If there is an edit, apply it to the event. edit = aggregations.replace + self._apply_edit(event, serialized_event, edit) - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - + # Include information about it in the relations dict. serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, @@ -478,13 +490,22 @@ class EventClientSerializer: # If this event is the start of a thread, include a summary of the replies. if aggregations.thread: + thread = aggregations.thread + + # Don't bundle aggregations as this could recurse forever. + serialized_latest_event = self.serialize_event( + thread.latest_event, time_now, bundle_aggregations=None + ) + # Manually apply an edit, if one exists. + if thread.latest_edit: + self._apply_edit( + thread.latest_event, serialized_latest_event, thread.latest_edit + ) + serialized_aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": self.serialize_event( - aggregations.thread.latest_event, time_now, bundle_aggregations=None - ), - "count": aggregations.thread.count, - "current_user_participated": aggregations.thread.current_user_participated, + "latest_event": serialized_latest_event, + "count": thread.count, + "current_user_participated": thread.current_user_participated, } # Include the bundled aggregations in the event. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 896168c05..41ac49fdc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -39,7 +39,7 @@ class FederationBase: self.server_name = hs.hostname self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._clock = hs.get_clock() async def _check_sigs_and_hash( @@ -47,6 +47,11 @@ class FederationBase: ) -> EventBase: """Checks that event is correctly signed by the sending server. + Also checks the content hash, and redacts the event if there is a mismatch. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + Args: room_version: The room version of the PDU pdu: the event to be checked @@ -55,7 +60,10 @@ class FederationBase: * the original event if the checks pass * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed.""" + + Raises: + SynapseError if the signature check failed. + """ try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) except SynapseError as e: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 74f17aa4d..64e595e74 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1,4 +1,4 @@ -# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2015-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -56,7 +56,7 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -89,6 +89,12 @@ class SendJoinResult: state: List[EventBase] auth_chain: List[EventBase] + # True if 'state' elides non-critical membership events + partial_state: bool + + # if 'partial_state' is set, a list of the servers in the room (otherwise empty) + servers_in_room: List[str] + class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): @@ -413,26 +419,90 @@ class FederationClient(FederationBase): return state_event_ids, auth_event_ids + async def get_room_state( + self, + destination: str, + room_id: str, + event_id: str, + room_version: RoomVersion, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Calls the /state endpoint to fetch the state at a particular point + in the room. + + Any invalid events (those with incorrect or unverifiable signatures or hashes) + are filtered out from the response, and any duplicate events are removed. + + (Size limits and other event-format checks are *not* performed.) + + Note that the result is not ordered, so callers must be careful to process + the events in an order that handles dependencies. + + Returns: + a tuple of (state events, auth events) + """ + result = await self.transport_layer.get_room_state( + room_version, + destination, + room_id, + event_id, + ) + state_events = result.state + auth_events = result.auth_events + + # we may as well filter out any duplicates from the response, to save + # processing them multiple times. (In particular, events may be present in + # `auth_events` as well as `state`, which is redundant). + # + # We don't rely on the sort order of the events, so we can just stick them + # in a dict. + state_event_map = {event.event_id: event for event in state_events} + auth_event_map = { + event.event_id: event + for event in auth_events + if event.event_id not in state_event_map + } + + logger.info( + "Processing from /state: %d state events, %d auth events", + len(state_event_map), + len(auth_event_map), + ) + + valid_auth_events = await self._check_sigs_and_hash_and_fetch( + destination, auth_event_map.values(), room_version + ) + + valid_state_events = await self._check_sigs_and_hash_and_fetch( + destination, state_event_map.values(), room_version + ) + + return valid_state_events, valid_auth_events + async def _check_sigs_and_hash_and_fetch( self, origin: str, pdus: Collection[EventBase], room_version: RoomVersion, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashes of each - one. If a PDU fails its signature check then we check if we have it in - the database and if not then request if from the originating server of - that PDU. + """Checks the signatures and hashes of a list of events. + + If a PDU fails its signature check then we check if we have it in + the database, and if not then request it from the sender's server (if that + is different from `origin`). If that still fails, the event is omitted from + the returned list. If a PDU fails its content hash check then it is redacted. - The given list of PDUs are not modified, instead the function returns + Also runs each event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + + The given list of PDUs are not modified; instead the function returns a new list. Args: - origin - pdu - room_version + origin: The server that sent us these events + pdus: The events to be checked + room_version: the version of the room these events are in Returns: A list of PDUs that have valid signatures and hashes. @@ -463,11 +533,16 @@ class FederationClient(FederationBase): origin: str, room_version: RoomVersion, ) -> Optional[EventBase]: - """Takes a PDU and checks its signatures and hashes. If the PDU fails - its signature check then we check if we have it in the database and if - not then request if from the originating server of that PDU. + """Takes a PDU and checks its signatures and hashes. - If then PDU fails its content hash check then it is redacted. + If the PDU fails its signature check then we check if we have it in the + database; if not, we then request it from sender's server (if that is not the + same as `origin`). If that still fails, we return None. + + If the PDU fails its content hash check, it is redacted. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. Args: origin @@ -540,11 +615,15 @@ class FederationClient(FederationBase): synapse_error = e.to_synapse_error() # There is no good way to detect an "unknown" endpoint. # - # Dendrite returns a 404 (with no body); synapse returns a 400 + # Dendrite returns a 404 (with a body of "404 page not found"); + # Conduit returns a 404 (with no body); and Synapse returns a 400 # with M_UNRECOGNISED. - return e.code == 404 or ( - e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED - ) + # + # This needs to be rather specific as some endpoints truly do return 404 + # errors. + return ( + e.code == 404 and (not e.response or e.response == b"404 page not found") + ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED) async def _try_destination_list( self, @@ -864,23 +943,32 @@ class FederationClient(FederationBase): for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - # double-check that the same create event has ended up in the auth chain + # double-check that the auth chain doesn't include a different create event auth_chain_create_events = [ e.event_id for e in signed_auth if (e.type, e.state_key) == (EventTypes.Create, "") ] - if auth_chain_create_events != [create_event.event_id]: + if auth_chain_create_events and auth_chain_create_events != [ + create_event.event_id + ]: raise InvalidResponseError( "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) ) + if response.partial_state and not response.servers_in_room: + raise InvalidResponseError( + "partial_state was set, but no servers were listed in the room" + ) + return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, + partial_state=response.partial_state, + servers_in_room=response.servers_in_room or [], ) # MSC3083 defines additional error codes for room joins. @@ -918,7 +1006,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -987,7 +1075,7 @@ class FederationClient(FederationBase): except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint if the room uses old-style event IDs. - # Otherwise consider it a legitmate error and raise. + # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() if self._is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.V1: @@ -1048,7 +1136,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. if not self._is_unknown_endpoint(e): raise @@ -1274,61 +1362,6 @@ class FederationClient(FederationBase): # server doesn't give it to us. return None - async def get_space_summary( - self, - destinations: Iterable[str], - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> "FederationSpaceSummaryResult": - """ - Call other servers to get a summary of the given space - - - Args: - destinations: The remote servers. We will try them in turn, omitting any - that have been blacklisted. - - room_id: ID of the space to be queried - - suggested_only: If true, ask the remote server to only return children - with the "suggested" flag set - - max_rooms_per_space: A limit on the number of children to return for each - space - - exclude_rooms: A list of room IDs to tell the remote server to skip - - Returns: - a parsed FederationSpaceSummaryResult - - Raises: - SynapseError if we were unable to get a valid summary from any of the - remote servers - """ - - async def send_request(destination: str) -> FederationSpaceSummaryResult: - res = await self.transport_layer.get_space_summary( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - exclude_rooms=exclude_rooms, - ) - - try: - return FederationSpaceSummaryResult.from_json_dict(res) - except ValueError as e: - raise InvalidResponseError(str(e)) - - return await self._try_destination_list( - "fetch space summary", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - async def get_room_hierarchy( self, destinations: Iterable[str], @@ -1374,8 +1407,8 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, - # fallback to the unstable endpoint. Otherwise consider it a - # legitmate error and raise. + # fallback to the unstable endpoint. Otherwise, consider it a + # legitimate error and raise. if not self._is_unknown_endpoint(e): raise @@ -1400,10 +1433,8 @@ class FederationClient(FederationBase): if any(not isinstance(e, dict) for e in children_state): raise InvalidResponseError("Invalid event in 'children_state' list") try: - [ - FederationSpaceSummaryEventResult.from_json_dict(e) - for e in children_state - ] + for child_state in children_state: + _validate_hierarchy_event(child_state) except ValueError as e: raise InvalidResponseError(str(e)) @@ -1425,62 +1456,12 @@ class FederationClient(FederationBase): return room, children_state, children, inaccessible_children - try: - result = await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) - except SynapseError as e: - # If an unexpected error occurred, re-raise it. - if e.code != 502: - raise - - logger.debug( - "Couldn't fetch room hierarchy, falling back to the spaces API" - ) - - # Fallback to the old federation API and translate the results if - # no servers implement the new API. - # - # The algorithm below is a bit inefficient as it only attempts to - # parse information for the requested room, but the legacy API may - # return additional layers. - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) - - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - result = (requested_room, children_events, children, ()) + result = await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) # Cache the result to avoid fetching data over federation every time. self._get_room_hierarchy_cache[(room_id, suggested_only)] = result @@ -1526,6 +1507,64 @@ class FederationClient(FederationBase): except ValueError as e: raise InvalidResponseError(str(e)) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> Tuple[JsonDict, List[str]]: + """Retrieves account statuses for a given list of users on a given remote + homeserver. + + If the request fails for any reason, all user IDs for this destination are marked + as failed. + + Args: + destination: the destination to contact + user_ids: the user ID(s) for which to request account status(es) + + Returns: + The account statuses, as well as the list of user IDs for which it was not + possible to retrieve a status. + """ + try: + res = await self.transport_layer.get_account_status(destination, user_ids) + except Exception: + # If the query failed for any reason, mark all the users as failed. + return {}, user_ids + + statuses = res.get("account_statuses", {}) + failures = res.get("failures", []) + + if not isinstance(statuses, dict) or not isinstance(failures, list): + # Make sure we're not feeding back malformed data back to the caller. + logger.warning( + "Destination %s responded with malformed data to account_status query", + destination, + ) + return {}, user_ids + + for user_id in user_ids: + # Any account whose status is missing is a user we failed to receive the + # status of. + if user_id not in statuses and user_id not in failures: + failures.append(user_id) + + # Filter out any user ID that doesn't belong to the remote server that sent its + # status (or failure). + def filter_user_id(user_id: str) -> bool: + try: + return UserID.from_string(user_id).domain == destination + except SynapseError: + # If the user ID doesn't parse, ignore it. + return False + + filtered_statuses = dict( + # item is a (key, value) tuple, so item[0] is the user ID. + filter(lambda item: filter_user_id(item[0]), statuses.items()) + ) + + filtered_failures = list(filter(filter_user_id, failures)) + + return filtered_statuses, filtered_failures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: @@ -1564,89 +1603,34 @@ class TimestampToEventResponse: return cls(event_id, origin_server_ts, d) -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryEventResult: - """Represents a single event in the result of a successful get_space_summary call. +def _validate_hierarchy_event(d: JsonDict) -> None: + """Validate an event within the result of a /hierarchy request - It's essentially just a serialised event object, but we do a bit of parsing and - validation in `from_json_dict` and store some of the validated properties in - object attributes. + Args: + d: json object to be parsed + + Raises: + ValueError if d is not a valid event """ - event_type: str - room_id: str - state_key: str - via: Sequence[str] + event_type = d.get("type") + if not isinstance(event_type, str): + raise ValueError("Invalid event: 'event_type' must be a str") - # the raw data, including the above keys - data: JsonDict + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": - """Parse an event within the result of a /spaces/ request + state_key = d.get("state_key") + if not isinstance(state_key, str): + raise ValueError("Invalid event: 'state_key' must be a str") - Args: - d: json object to be parsed + content = d.get("content") + if not isinstance(content, dict): + raise ValueError("Invalid event: 'content' must be a dict") - Raises: - ValueError if d is not a valid event - """ - - event_type = d.get("type") - if not isinstance(event_type, str): - raise ValueError("Invalid event: 'event_type' must be a str") - - room_id = d.get("room_id") - if not isinstance(room_id, str): - raise ValueError("Invalid event: 'room_id' must be a str") - - state_key = d.get("state_key") - if not isinstance(state_key, str): - raise ValueError("Invalid event: 'state_key' must be a str") - - content = d.get("content") - if not isinstance(content, dict): - raise ValueError("Invalid event: 'content' must be a dict") - - via = content.get("via") - if not isinstance(via, Sequence): - raise ValueError("Invalid event: 'via' must be a list") - if any(not isinstance(v, str) for v in via): - raise ValueError("Invalid event: 'via' must be a list of strings") - - return cls(event_type, room_id, state_key, via, d) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class FederationSpaceSummaryResult: - """Represents the data returned by a successful get_space_summary call.""" - - rooms: List[JsonDict] - events: Sequence[FederationSpaceSummaryEventResult] - - @classmethod - def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": - """Parse the result of a /spaces/ request - - Args: - d: json object to be parsed - - Raises: - ValueError if d is not a valid /spaces/ response - """ - rooms = d.get("rooms") - if not isinstance(rooms, List): - raise ValueError("'rooms' must be a list") - if any(not isinstance(r, dict) for r in rooms): - raise ValueError("Invalid room in 'rooms' list") - - events = d.get("events") - if not isinstance(events, Sequence): - raise ValueError("'events' must be a list") - if any(not isinstance(e, dict) for e in events): - raise ValueError("Invalid event in 'events' list") - parsed_events = [ - FederationSpaceSummaryEventResult.from_json_dict(e) for e in events - ] - - return cls(rooms, parsed_events) + via = content.get("via") + if not isinstance(via, Sequence): + raise ValueError("Invalid event: 'via' must be a list") + if any(not isinstance(v, str) for v in via): + raise ValueError("Invalid event: 'via' must be a list of strings") diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 720d7bd74..6106a486d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -228,7 +228,7 @@ class FederationSender(AbstractFederationSender): self.hs = hs self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8152e80b8..c8768f22b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -76,7 +76,7 @@ class PerDestinationQueue: ): self._server_name = hs.hostname self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config @@ -381,7 +381,8 @@ class PerDestinationQueue: ) ) - if self._last_successful_stream_ordering is None: + _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering + if _tmp_last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database ­ we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -391,11 +392,12 @@ class PerDestinationQueue: self._catching_up = False return + last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering + # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, - self._last_successful_stream_ordering, + self._destination, last_successful_stream_ordering ) if not event_ids: @@ -403,7 +405,7 @@ class PerDestinationQueue: # of a race condition, so we check that no new events have been # skipped due to us being in catch-up mode - if self._catchup_last_skipped > self._last_successful_stream_ordering: + if self._catchup_last_skipped > last_successful_stream_ordering: # another event has been skipped because we were in catch-up mode continue @@ -470,7 +472,7 @@ class PerDestinationQueue: # offline if ( p.internal_metadata.stream_ordering - < self._last_successful_stream_ordering + < last_successful_stream_ordering ): continue @@ -513,12 +515,11 @@ class PerDestinationQueue: # from the *original* PDU, rather than the PDU(s) we actually # send. This is because we use it to mark our position in the # queue of missed PDUs to process. - self._last_successful_stream_ordering = ( - pdu.internal_metadata.stream_ordering - ) + last_successful_stream_ordering = pdu.internal_metadata.stream_ordering + self._last_successful_stream_ordering = last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering + self._destination, last_successful_stream_ordering ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 742ee5725..0c1cad86a 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -53,7 +53,7 @@ class TransactionManager: def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8782586cd..de6e5f44f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,4 +1,4 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -60,17 +60,17 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname self.client = hs.get_federation_http_client() + self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: - """Requests all state for a given room from the given server at the - given event. Returns the state's event_id's + """Requests the IDs of all state for a given room at the given event. Args: destination: The host name of the remote homeserver we want to get the state from. - context: The name of the context we want the state of + room_id: the room we want the state of event_id: The event we want the context at. Returns: @@ -86,6 +86,29 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) + async def get_room_state( + self, room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> "StateRequestResponse": + """Requests the full state for a given room at the given event. + + Args: + room_version: the version of the room (required to build the event objects) + destination: The host name of the remote homeserver we want + to get the state from. + room_id: the room we want the state of + event_id: The event we want the context at. + + Returns: + Results in a dict received from the remote homeserver. + """ + path = _create_v1_path("/state/%s", room_id) + return await self.client.get_json( + destination, + path=path, + args={"event_id": event_id}, + parser=_StateParser(room_version), + ) + async def get_event( self, destination: str, event_id: str, timeout: Optional[int] = None ) -> JsonDict: @@ -235,8 +258,9 @@ class TransportLayerClient: args: dict, retry_on_dns_fail: bool, ignore_backoff: bool = False, + prefix: str = FEDERATION_V1_PREFIX, ) -> JsonDict: - path = _create_v1_path("/query/%s", query_type) + path = _create_path(prefix, "/query/%s", query_type) return await self.client.get_json( destination=destination, @@ -336,10 +360,15 @@ class TransportLayerClient: content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + query_params: Dict[str, str] = {} + if self._faster_joins_enabled: + # lazy-load state on join + query_params["org.matrix.msc3706.partial_state"] = "true" return await self.client.put_json( destination=destination, path=path, + args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, @@ -1150,39 +1179,6 @@ class TransportLayerClient: return await self.client.get_json(destination=destination, path=path) - async def get_space_summary( - self, - destination: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: List[str], - ) -> JsonDict: - """ - Args: - destination: The remote server - room_id: The room ID to ask about. - suggested_only: if True, only suggested rooms will be returned - max_rooms_per_space: an optional limit to the number of children to be - returned per space - exclude_rooms: a list of any rooms we can skip - """ - # TODO When switching to the stable endpoint, use GET instead of POST. - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id - ) - - params = { - "suggested_only": suggested_only, - "exclude_rooms": exclude_rooms, - } - if max_rooms_per_space is not None: - params["max_rooms_per_space"] = max_rooms_per_space - - return await self.client.post_json( - destination=destination, path=path, data=params - ) - async def get_room_hierarchy( self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: @@ -1219,6 +1215,22 @@ class TransportLayerClient: args={"suggested_only": "true" if suggested_only else "false"}, ) + async def get_account_status( + self, destination: str, user_ids: List[str] + ) -> JsonDict: + """ + Args: + destination: The remote server. + user_ids: The user ID(s) for which to request account status(es). + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status" + ) + + return await self.client.post_json( + destination=destination, path=path, data={"user_ids": user_ids} + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ @@ -1271,6 +1283,20 @@ class SendJoinResponse: # "event" is not included in the response. event: Optional[EventBase] = None + # The room state is incomplete + partial_state: bool = False + + # List of servers in the room + servers_in_room: Optional[List[str]] = None + + +@attr.s(slots=True, auto_attribs=True) +class StateRequestResponse: + """The parsed response of a `/state` request.""" + + auth_events: List[EventBase] + state: List[EventBase] + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: @@ -1297,6 +1323,32 @@ def _event_list_parser( events.append(event) +@ijson.coroutine +def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the partial_state field in send_join responses + """ + while True: + val = yield + if not isinstance(val, bool): + raise TypeError("partial_state must be a boolean") + response.partial_state = val + + +@ijson.coroutine +def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the servers_in_room field in send_join responses + """ + while True: + val = yield + if not isinstance(val, list) or any(not isinstance(x, str) for x in val): + raise TypeError("servers_in_room must be a list of strings") + response.servers_in_room = val + + class SendJoinParser(ByteParser[SendJoinResponse]): """A parser for the response to `/send_join` requests. @@ -1308,44 +1360,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], [], {}) + self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version + self._coros = [] # The V1 API has the shape of `[200, {...}]`, which we handle by # prefixing with `item.*`. prefix = "item." if v1_api else "" - self._coro_state = ijson.items_coro( - _event_list_parser(room_version, self._response.state), - prefix + "state.item", - use_float=True, - ) - self._coro_auth = ijson.items_coro( - _event_list_parser(room_version, self._response.auth_events), - prefix + "auth_chain.item", - use_float=True, - ) - # TODO Remove the unstable prefix when servers have updated. - # - # By re-using the same event dictionary this will cause the parsing of - # org.matrix.msc3083.v2.event and event to stomp over each other. - # Generally this should be fine. - self._coro_unstable_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "org.matrix.msc3083.v2.event", - use_float=True, - ) - self._coro_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "event", - use_float=True, - ) + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + prefix + "state.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + prefix + "auth_chain.item", + use_float=True, + ), + # TODO Remove the unstable prefix when servers have updated. + # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "org.matrix.msc3083.v2.event", + use_float=True, + ), + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ), + ] + + if not v1_api: + self._coros.append( + ijson.items_coro( + _partial_state_parser(self._response), + "org.matrix.msc3706.partial_state", + use_float="True", + ) + ) + + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "org.matrix.msc3706.servers_in_room", + use_float="True", + ) + ) def write(self, data: bytes) -> int: - self._coro_state.send(data) - self._coro_auth.send(data) - self._coro_unstable_event.send(data) - self._coro_event.send(data) + for c in self._coros: + c.send(data) return len(data) @@ -1355,3 +1425,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]): self._response.event_dict, self._room_version ) return self._response + + +class _StateParser(ByteParser[StateRequestResponse]): + """A parser for the response to `/state` requests. + + Args: + room_version: The version of the room. + """ + + CONTENT_TYPE = "application/json" + + def __init__(self, room_version: RoomVersion): + self._response = StateRequestResponse([], []) + self._room_version = room_version + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + "pdus.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + "auth_chain.item", + use_float=True, + ), + ] + + def write(self, data: bytes) -> int: + for c in self._coros: + c.send(data) + return len(data) + + def finish(self) -> StateRequestResponse: + return self._response diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index db4fe2c79..67a634790 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import ( ) from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, + FederationAccountStatusServlet, FederationTimestampLookupServlet, ) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES @@ -336,6 +337,13 @@ def register_servlets( ): continue + # Only allow the `/account_status` servlet if msc3720 is enabled + if ( + servletclass == FederationAccountStatusServlet + and not hs.config.experimental.msc3720_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dff2b6835..87e99c7dd 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -55,7 +55,7 @@ class Authenticator: self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index e85a8eda5..aed3d5069 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -110,7 +110,7 @@ class FederationSendServlet(BaseFederationServerServlet): if issue_8631_logger.isEnabledFor(logging.DEBUG): DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] device_list_updates = [ - edu.content + edu.get("content", {}) for edu in transaction_data.get("edus", []) if edu.get("edu_type") in DEVICE_UPDATE_EDUS ] @@ -624,81 +624,6 @@ class FederationVersionServlet(BaseFederationServlet): ) -class FederationSpaceSummaryServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/spaces/(?P[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - - max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - exclude_rooms = content.get("exclude_rooms", []) - if not isinstance(exclude_rooms, list) or any( - not isinstance(x, str) for x in exclude_rooms - ): - raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - class FederationRoomHierarchyServlet(BaseFederationServlet): PATH = "/hierarchy/(?P[^/]*)" @@ -746,7 +671,7 @@ class RoomComplexityServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() + self._store = self.hs.get_datastores().main async def on_GET( self, @@ -766,6 +691,40 @@ class RoomComplexityServlet(BaseFederationServlet): return 200, complexity +class FederationAccountStatusServlet(BaseFederationServerServlet): + PATH = "/query/account_status" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720" + + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._account_handler = hs.get_account_handler() + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + if "user_ids" not in content: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + content["user_ids"], + allow_remote=False, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationSendServlet, FederationEventServlet, @@ -792,9 +751,9 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, - FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, + FederationAccountStatusServlet, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a87896e53..ed26d6a6c 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -140,7 +140,7 @@ class GroupAttestionRenewer: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 449bbc700..4c3a5a6e2 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -45,7 +45,7 @@ MAX_LONG_DESC_LEN = 10000 class GroupsServerWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py new file mode 100644 index 000000000..d5badf635 --- /dev/null +++ b/synapse/handlers/account.py @@ -0,0 +1,144 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, List, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class AccountHandler: + def __init__(self, hs: "HomeServer"): + self._main_store = hs.get_datastores().main + self._is_mine = hs.is_mine + self._federation_client = hs.get_federation_client() + + async def get_account_statuses( + self, + user_ids: List[str], + allow_remote: bool, + ) -> Tuple[JsonDict, List[str]]: + """Get account statuses for a list of user IDs. + + If one or more account(s) belong to remote homeservers, retrieve their status(es) + over federation if allowed. + + Args: + user_ids: The list of accounts to retrieve the status of. + allow_remote: Whether to try to retrieve the status of remote accounts, if + any. + + Returns: + The account statuses as well as the list of users whose statuses could not be + retrieved. + + Raises: + SynapseError if a required parameter is missing or malformed, or if one of + the accounts isn't local to this homeserver and allow_remote is False. + """ + statuses = {} + failures = [] + remote_users: List[UserID] = [] + + for raw_user_id in user_ids: + try: + user_id = UserID.from_string(raw_user_id) + except SynapseError: + raise SynapseError( + 400, + f"Not a valid Matrix user ID: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + if self._is_mine(user_id): + status = await self._get_local_account_status(user_id) + statuses[user_id.to_string()] = status + else: + if not allow_remote: + raise SynapseError( + 400, + f"Not a local user: {raw_user_id}", + Codes.INVALID_PARAM, + ) + + remote_users.append(user_id) + + if allow_remote and len(remote_users) > 0: + remote_statuses, remote_failures = await self._get_remote_account_statuses( + remote_users, + ) + + statuses.update(remote_statuses) + failures += remote_failures + + return statuses, failures + + async def _get_local_account_status(self, user_id: UserID) -> JsonDict: + """Retrieve the status of a local account. + + Args: + user_id: The account to retrieve the status of. + + Returns: + The account's status. + """ + status = {"exists": False} + + userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) + + if userinfo is not None: + status = { + "exists": True, + "deactivated": userinfo.is_deactivated, + } + + return status + + async def _get_remote_account_statuses( + self, remote_users: List[UserID] + ) -> Tuple[JsonDict, List[str]]: + """Send out federation requests to retrieve the statuses of remote accounts. + + Args: + remote_users: The accounts to retrieve the statuses of. + + Returns: + The statuses of the accounts, and a list of accounts for which no status + could be retrieved. + """ + # Group remote users by destination, so we only send one request per remote + # homeserver. + by_destination: Dict[str, List[str]] = {} + for user in remote_users: + if user.domain not in by_destination: + by_destination[user.domain] = [] + + by_destination[user.domain].append(user.to_string()) + + # Retrieve the statuses and failures for remote accounts. + final_statuses: JsonDict = {} + final_failures: List[str] = [] + for destination, users in by_destination.items(): + statuses, failures = await self._federation_client.get_account_status( + destination, + users, + ) + + final_statuses.update(statuses) + final_failures += failures + + return final_statuses, final_failures diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index bad48713b..177b4f899 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: class AccountDataHandler: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() @@ -166,7 +166,7 @@ class AccountDataHandler: class AccountDataEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 87e415df7..9d0975f63 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -43,7 +43,7 @@ class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 00ab5e79b..96376963f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index a42c3558e..e6461cc3c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -47,7 +47,7 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() self.scheduler = hs.get_application_service_scheduler() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7..3e29c96a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -194,7 +194,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} @@ -1183,7 +1183,7 @@ class AuthHandler: # No password providers were able to handle this 3pid # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( medium, address ) if not user_id: @@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ class PasswordAuthProvider: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ class PasswordAuthProvider: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ class PasswordAuthProvider: get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ class PasswordAuthProvider: return None + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 5d8f6c50a..7163af800 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -61,7 +61,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7a13d76a6..76ae768e6 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -38,6 +38,7 @@ class DeactivateAccountHandler: self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname + self._third_party_rules = hs.get_third_party_event_rules() # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False @@ -135,9 +136,13 @@ class DeactivateAccountHandler: if erase_data: user = UserID.from_string(user_id) # Remove avatar URL from this user - await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + await self._profile_handler.set_avatar_url( + user, requester, "", by_admin, deactivation=True + ) # Remove displayname from this user - await self._profile_handler.set_displayname(user, requester, "", by_admin) + await self._profile_handler.set_displayname( + user, requester, "", by_admin, deactivation=True + ) logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) @@ -160,6 +165,13 @@ class DeactivateAccountHandler: # Remove account data (including ignored users and push rules). await self.store.purge_account_data_for_user(user_id) + # Let modules know the user has been deactivated. + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, + True, + by_admin, + ) + return identity_server_supports_unbinding async def _reject_pending_invites_for_user(self, user_id: str) -> None: @@ -264,6 +276,10 @@ class DeactivateAccountHandler: # Mark the user as active. await self.store.set_user_deactivated_status(user_id, False) + await self._third_party_rules.on_user_deactivation_status_changed( + user_id, False, True + ) + # Add the user to the directory, if necessary. Note that # this must be done after the user is re-activated, because # deactivated users are excluded from the user directory. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 36c05f836..934b5bd73 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -63,7 +63,7 @@ class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state @@ -628,7 +628,7 @@ class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b582266af..4cb725d02 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -43,7 +43,7 @@ class DeviceMessageHandler: Args: hs: server """ - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 50129d728..95f147f0e 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -44,7 +44,7 @@ class DirectoryHandler: self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d4dfddf63..d96456cd4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine @@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.e2e_keys_handler = e2e_keys_handler diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 12614b2c5..52e44a2d4 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -45,7 +45,7 @@ class E2eRoomKeysHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # Used to lock whenever a client is uploading key data. This prevents collisions # between clients trying to upload the details of a new session, given all diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 365063ebd..d441ebb0a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -43,7 +43,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname async def check_auth_rules_from_context( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index bac5de052..97e75e60c 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -134,7 +134,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1b3cf3820..70cd819ee 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,8 +49,8 @@ from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, preserve_fn, - run_in_background, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, @@ -107,7 +107,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.federation_client = hs.get_federation_client() @@ -516,11 +516,20 @@ class FederationHandler: await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, - auth_events=auth_chain, + state_events=state, ) + if ret.partial_state: + await self.store.store_partial_state_room(room_id, ret.servers_in_room) + max_stream_id = await self._federation_event_handler.process_remote_join( - origin, room_id, auth_chain, state, event, room_version_obj + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, ) # We wait here until this instance has seen the events come down @@ -559,7 +568,9 @@ class FederationHandler: # lots of requests for missing prev_events which we do actually # have. Hence we fire off the background task, but don't wait for it. - run_in_background(self._handle_queued_pdus, room_queue) + run_as_background_process( + "handle_queued_pdus", self._handle_queued_pdus, room_queue + ) async def do_knock( self, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9edc7369d..4bd87709f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -95,7 +95,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._storage = hs.get_storage() self._state_store = self._storage.state @@ -397,6 +397,7 @@ class FederationEventHandler: state: List[EventBase], event: EventBase, room_version: RoomVersion, + partial_state: bool, ) -> int: """Persists the events returned by a send_join @@ -412,6 +413,7 @@ class FederationEventHandler: event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + partial_state: True if the state omits non-critical membership events Returns: The stream ID after which all events have been persisted. @@ -419,10 +421,8 @@ class FederationEventHandler: Raises: SynapseError if the response is in some way invalid. """ - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} - create_event = None - for e in auth_events: + for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break @@ -439,11 +439,6 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - # filter out any events we have already seen - seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) - for s in seen_remotes: - event_map.pop(s, None) - # persist the auth chain and state events. # # any invalid events here will be marked as rejected, and we'll carry on. @@ -455,13 +450,19 @@ class FederationEventHandler: # signatures right now doesn't mean that we will *never* be able to, so it # is premature to reject them. # - await self._auth_and_persist_outliers(room_id, event_map.values()) + await self._auth_and_persist_outliers( + room_id, itertools.chain(auth_events, state) + ) # and now persist the join event itself. - logger.info("Peristing join-via-remote %s", event) + logger.info( + "Peristing join-via-remote %s (partial_state: %s)", event, partial_state + ) with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( - event, old_state=state + event, + old_state=state, + partial_state=partial_state, ) context = await self._check_event_auth(origin, event, context) @@ -703,6 +704,8 @@ class FederationEventHandler: try: state = await self._resolve_state_at_missing_prevs(origin, event) + # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does + # not return partial state await self._process_received_pdu( origin, event, state=state, backfilled=backfilled ) @@ -1245,6 +1248,16 @@ class FederationEventHandler: """ event_map = {event.event_id: event for event in events} + # filter out any events we have already seen. This might happen because + # the events were eagerly pushed to us (eg, during a room join), or because + # another thread has raced against us since we decided to request the event. + # + # This is just an optimisation, so it doesn't need to be watertight - the event + # persister does another round of deduplication. + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + # XXX: it might be possible to kick this process off in parallel with fetching # the events. while event_map: @@ -1717,31 +1730,22 @@ class FederationEventHandler: event_id: the event for which we are lacking auth events """ try: - remote_event_map = { - e.event_id: e - for e in await self._federation_client.get_event_auth( - destination, room_id, event_id - ) - } + remote_events = await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) return - logger.info("/event_auth returned %i events", len(remote_event_map)) + logger.info("/event_auth returned %i events", len(remote_events)) # `event` may be returned, but we should not yet process it. - remote_event_map.pop(event_id, None) + remote_auth_events = (e for e in remote_events if e.event_id != event_id) - # nor should we reprocess any events we have already seen. - seen_remotes = await self._store.have_seen_events( - room_id, remote_event_map.keys() - ) - for s in seen_remotes: - remote_event_map.pop(s, None) - - await self._auth_and_persist_outliers(room_id, remote_event_map.values()) + await self._auth_and_persist_outliers(room_id, remote_auth_events) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] @@ -1795,6 +1799,7 @@ class FederationEventHandler: prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, + partial_state=context.partial_state, ) async def _run_push_actions_and_persist_event( diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9e270d461..e7a399787 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -63,7 +63,7 @@ def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: class GroupsLocalWorkerHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_list_handler = hs.get_room_list_handler() self.groups_server_handler = hs.get_groups_server_handler() self.transport_client = hs.get_federation_transport_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c83eaea35..57c9fdfe6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -49,7 +49,7 @@ id_server_scheme = "https://" class IdentityHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 346a06ff4..344f20f37 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a0b585f6..4c8958458 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError +from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -75,7 +75,7 @@ class MessageHandler: self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @@ -397,7 +397,7 @@ class EventCreationHandler: self.hs = hs self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -552,10 +552,11 @@ class EventCreationHandler: if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: # this can happen if support is withdrawn for a room version raise UnsupportedRoomVersionError(room_version_id) + room_version_obj = maybe_room_version_obj else: try: room_version_obj = await self.store.get_room_version( @@ -993,6 +994,8 @@ class EventCreationHandler: and full_state_ids_at_event and builder.internal_metadata.is_historical() ): + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. old_state = await self.store.get_events_as_list(full_state_ids_at_event) context = await self.state.compute_event_context(event, old_state=old_state) else: @@ -1147,12 +1150,13 @@ class EventCreationHandler: room_version_id = event.content.get( "room_version", RoomVersions.V1.identifier ) - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: raise UnsupportedRoomVersionError( "Attempt to create a room with unsupported room version %s" % (room_version_id,) ) + room_version_obj = maybe_room_version_obj else: room_version_obj = await self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 8f71d975e..593a2aac6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -273,7 +273,7 @@ class OidcProvider: token_generator: "OidcSessionTokenGenerator", provider: OidcProviderConfig, ): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._token_generator = token_generator diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 973f26296..5c01a426f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -127,7 +127,7 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state self.clock = hs.get_clock() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 067c43ae4..c155098be 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): Returns: dict: `user_id` -> `UserPresenceState` """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } + states = {} + missing = [] + for user_id in user_ids: + state = self.user_to_current_state.get(user_id, None) + if state: + states[user_id] = state + else: + missing.append(user_id) - missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) + for user_id in missing: + # if user has no state in database, create the state + if not res.get(user_id, None): + new_state = UserPresenceState.default(user_id) + states[user_id] = new_state + self.user_to_current_state[user_id] = new_state return states @@ -1539,7 +1541,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36e3ad2ba..6554c0d3c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -54,7 +54,7 @@ class ProfileHandler: PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs @@ -71,6 +71,8 @@ class ProfileHandler: self.server_name = hs.config.server.server_name + self._third_party_rules = hs.get_third_party_event_rules() + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS @@ -171,6 +173,7 @@ class ProfileHandler: requester: Requester, new_displayname: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set the displayname of a user @@ -179,6 +182,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_displayname: The displayname to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -227,6 +231,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: @@ -261,6 +269,7 @@ class ProfileHandler: requester: Requester, new_avatar_url: str, by_admin: bool = False, + deactivation: bool = False, ) -> None: """Set a new avatar URL for a user. @@ -269,6 +278,7 @@ class ProfileHandler: requester: The user attempting to make this change. new_avatar_url: The avatar URL to give this user. by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -315,6 +325,10 @@ class ProfileHandler: target_user.to_string(), profile ) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + await self._update_join_states(requester, target_user) @cached() diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 29562d9e3..9da143754 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 704ba7fd6..5a7d0fdb7 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler: def __init__(self, hs: "HomeServer"): self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() self.hs = hs @@ -164,7 +164,7 @@ class ReceiptsHandler: class ReceiptEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.config = hs.config @staticmethod diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0a4a48696..e87120fce 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -86,7 +86,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -327,12 +327,12 @@ class RegistrationHandler: if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self.store.generate_user_id() - user = UserID(localpart, self.hs.hostname) + generated_localpart = await self.store.generate_user_id() + user = UserID(generated_localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) if generate_display_name: - default_display_name = localpart + default_display_name = generated_localpart try: await self.register_with_store( user_id=user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2a7a12f04..e79c08126 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -105,7 +105,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -1125,7 +1125,7 @@ class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state @@ -1256,7 +1256,7 @@ class RoomContextHandler: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1396,7 +1396,7 @@ class TimestampLookupHandler: class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def get_new_events( self, @@ -1486,7 +1486,7 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def shutdown_room( self, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 94382158e..d237fd207 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 1a33211a1..f3577b5d5 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -49,7 +49,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index bf1a47efb..a582837cf 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer: Linearizer = Linearizer(name="member") + self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -500,25 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): key = (room_id,) - with (await self.member_linearizer.queue(key)): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - ) + as_id = object() + if requester.app_service: + as_id = requester.app_service.id + + # We first linearise by the application service (to try to limit concurrent joins + # by application services), and then by room ID. + with (await self.member_as_limiter.queue(as_id)): + with (await self.member_linearizer.queue(key)): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + historical=historical, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + ) return result diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4844b69a0..55c2cbdba 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -15,7 +15,6 @@ import itertools import logging import re -from collections import deque from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple import attr @@ -90,7 +89,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -107,153 +106,6 @@ class RoomSummaryHandler: "get_room_hierarchy", ) - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! - if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) # historical - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - async def get_room_hierarchy( self, requester: Requester, @@ -398,8 +250,6 @@ class RoomSummaryHandler: None, room_id, suggested_only, - # Do not limit the maximum children. - max_children=None, ) # Otherwise, attempt to use information for federation. @@ -488,74 +338,6 @@ class RoomSummaryHandler: return result - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - # Set a limit on the number of rooms to return. - if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: - max_rooms_per_space = MAX_ROOMS_PER_SPACE - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - async def get_federation_hierarchy( self, origin: str, @@ -579,7 +361,7 @@ class RoomSummaryHandler: The JSON hierarchy dictionary. """ root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None + None, origin, requested_room_id, suggested_only ) if root_room_entry is None: # Room is inaccessible to the requesting server. @@ -600,7 +382,7 @@ class RoomSummaryHandler: continue room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 + None, origin, room_id, suggested_only, include_children=False ) # If the room is accessible, include it in the results. # @@ -626,7 +408,7 @@ class RoomSummaryHandler: origin: Optional[str], room_id: str, suggested_only: bool, - max_children: Optional[int], + include_children: bool = True, ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -641,9 +423,8 @@ class RoomSummaryHandler: room_id: The room ID to summarize. suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. A value of None - means no limit. + include_children: + Whether to include the events of any children. Returns: A room entry if the room should be returned. None, otherwise. @@ -653,9 +434,8 @@ class RoomSummaryHandler: room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + # If the room is not a space return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or not include_children: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -665,14 +445,6 @@ class RoomSummaryHandler: # we only care about suggested children child_events = filter(_is_suggested_child_event, child_events) - # TODO max_children is legacy code for the /spaces endpoint. - if max_children is not None: - child_iter: Iterable[EventBase] = itertools.islice( - child_events, max_children - ) - else: - child_iter = child_events - stripped_events: List[JsonDict] = [ { "type": e.type, @@ -682,80 +454,10 @@ class RoomSummaryHandler: "sender": e.sender, "origin_server_ts": e.origin_server_ts, } - for e in child_iter + for e in child_events ] return _RoomEntry(room_id, room_entry, stripped_events) - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - async def _summarize_remote_room_hierarchy( self, room: "_RoomQueueEntry", suggested_only: bool ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: @@ -958,9 +660,8 @@ class RoomSummaryHandler: ): return True - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + # Check if the user is a member of any of the allowed rooms from the response. + allowed_rooms = room.get("allowed_room_ids") if allowed_rooms and isinstance(allowed_rooms, list): if await self._event_auth_handler.is_user_in_rooms( allowed_rooms, requester @@ -1028,8 +729,6 @@ class RoomSummaryHandler: ) if allowed_rooms: entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} @@ -1094,7 +793,7 @@ class RoomSummaryHandler: room_id, # Suggested-only doesn't matter since no children are requested. suggested_only=False, - max_children=0, + include_children=False, ) if not room_entry: diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 727d75a50..9602f0d0b 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -52,7 +52,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 41cb80907..aa16e417e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,8 +14,9 @@ import itertools import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership @@ -32,9 +33,23 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _SearchResult: + # The count of results. + count: int + # A mapping of event ID to the rank of that event. + rank_map: Dict[str, int] + # A list of the resulting events. + allowed_events: List[EventBase] + # A map of room ID to results. + room_groups: Dict[str, JsonDict] + # A set of event IDs to highlight. + highlights: Set[str] + + class SearchHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.clock = hs.get_clock() self.hs = hs @@ -100,7 +115,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user + user: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -156,6 +171,8 @@ class SearchHandler: # Include context around each event? event_context = room_cat.get("event_context", None) + before_limit = after_limit = None + include_profile = False # Group results together? May allow clients to paginate within a # group @@ -182,6 +199,73 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) + return await self._search( + user, + batch_group, + batch_group_key, + batch_token, + search_term, + keys, + filter_dict, + order_by, + include_state, + group_keys, + event_context, + before_limit, + after_limit, + include_profile, + ) + + async def _search( + self, + user: UserID, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + search_term: str, + keys: List[str], + filter_dict: JsonDict, + order_by: str, + include_state: bool, + group_keys: List[str], + event_context: Optional[bool], + before_limit: Optional[int], + after_limit: Optional[int], + include_profile: bool, + ) -> JsonDict: + """Performs a full text search for a user. + + Args: + user: The user performing the search. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + filter_dict: The JSON to build a filter out of. + order_by: How to order the results. Valid values ore "rank" and "recent". + include_state: True if the state of the room at each result should + be included. + group_keys: A list of ways to group the results. Valid values are + "room_id" and "sender". + event_context: True to include contextual events around results. + before_limit: + The number of events before a result to include as context. + + Only used if event_context is True. + after_limit: + The number of events after a result to include as context. + + Only used if event_context is True. + include_profile: True if historical profile information should be + included in the event context. + + Only used if event_context is True. + + Returns: + dict to be returned to the client with results of search + """ search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too @@ -216,209 +300,57 @@ class SearchHandler: } } - rank_map = {} # event_id -> rank of event - allowed_events = [] - # Holds result of grouping by room, if applicable - room_groups: Dict[str, JsonDict] = {} - # Holds result of grouping by sender, if applicable - sender_group: Dict[str, JsonDict] = {} - - # Holds the next_batch for the entire result set if one of those exists - global_next_batch = None - - highlights = set() - - count = None + sender_group: Optional[Dict[str, JsonDict]] if order_by == "rank": - search_result = await self.store.search_msgs(room_ids, search_term, keys) - - count = search_result["count"] - - if search_result["highlights"]: - highlights.update(search_result["highlights"]) - - results = search_result["results"] - - rank_map.update({r["event"].event_id: r["rank"] for r in results}) - - filtered_events = await search_filter.filter([r["event"] for r in results]) - - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + search_result, sender_group = await self._search_by_rank( + user, room_ids, search_term, keys, search_filter ) - - events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit] - - for e in allowed_events: - rm = room_groups.setdefault( - e.room_id, {"results": [], "order": rank_map[e.event_id]} - ) - rm["results"].append(e.event_id) - - s = sender_group.setdefault( - e.sender, {"results": [], "order": rank_map[e.event_id]} - ) - s["results"].append(e.event_id) - + # Unused return values for rank search. + global_next_batch = None elif order_by == "recent": - room_events: List[EventBase] = [] - i = 0 - - pagination_token = batch_token - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit and i < 5: - i += 1 - search_result = await self.store.search_rooms( - room_ids, - search_term, - keys, - search_filter.limit * 2, - pagination_token=pagination_token, - ) - - if search_result["highlights"]: - highlights.update(search_result["highlights"]) - - count = search_result["count"] - - results = search_result["results"] - - results_map = {r["event"].event_id: r for r in results} - - rank_map.update({r["event"].event_id: r["rank"] for r in results}) - - filtered_events = await search_filter.filter( - [r["event"] for r in results] - ) - - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events - ) - - room_events.extend(events) - room_events = room_events[: search_filter.limit] - - if len(results) < search_filter.limit * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - for event in room_events: - group = room_groups.setdefault(event.room_id, {"results": []}) - group["results"].append(event.event_id) - - if room_events and len(room_events) >= search_filter.limit: - last_event_id = room_events[-1].event_id - pagination_token = results_map[last_event_id]["pagination_token"] - - # We want to respect the given batch group and group keys so - # that if people blindly use the top level `next_batch` token - # it returns more from the same group (if applicable) rather - # than reverting to searching all results again. - if batch_group and batch_group_key: - global_next_batch = encode_base64( - ( - "%s\n%s\n%s" - % (batch_group, batch_group_key, pagination_token) - ).encode("ascii") - ) - else: - global_next_batch = encode_base64( - ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") - ) - - for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64( - ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( - "ascii" - ) - ) - - allowed_events.extend(room_events) - + search_result, global_next_batch = await self._search_by_recent( + user, + room_ids, + search_term, + keys, + search_filter, + batch_group, + batch_group_key, + batch_token, + ) + # Unused return values for recent search. + sender_group = None else: # We should never get here due to the guard earlier. raise NotImplementedError() - logger.info("Found %d events to return", len(allowed_events)) + logger.info("Found %d events to return", len(search_result.allowed_events)) # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = self.hs.get_event_sources().get_current_token() + # Note that before and after limit must be set in this case. + assert before_limit is not None + assert after_limit is not None - contexts = {} - for event in allowed_events: - res = await self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit - ) - - logger.info( - "Context for search returned %d and %d events", - len(res.events_before), - len(res.events_after), - ) - - events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before - ) - - events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after - ) - - context = { - "events_before": events_before, - "events_after": events_after, - "start": await now_token.copy_and_replace( - "room_key", res.start - ).to_string(self.store), - "end": await now_token.copy_and_replace( - "room_key", res.end - ).to_string(self.store), - } - - if include_profile: - senders = { - ev.sender - for ev in itertools.chain(events_before, [event], events_after) - } - - if events_after: - last_event_id = events_after[-1].event_id - else: - last_event_id = event.event_id - - state_filter = StateFilter.from_types( - [(EventTypes.Member, sender) for sender in senders] - ) - - state = await self.state_store.get_state_for_event( - last_event_id, state_filter - ) - - context["profile_info"] = { - s.state_key: { - "displayname": s.content.get("displayname", None), - "avatar_url": s.content.get("avatar_url", None), - } - for s in state.values() - if s.type == EventTypes.Member and s.state_key in senders - } - - contexts[event.event_id] = context + contexts = await self._calculate_event_contexts( + user, + search_result.allowed_events, + before_limit, + after_limit, + include_profile, + ) else: contexts = {} # TODO: Add a limit - time_now = self.clock.time_msec() + state_results = {} + if include_state: + for room_id in {e.room_id for e in search_result.allowed_events}: + state = await self.state_handler.get_current_state(room_id) + state_results[room_id] = list(state.values()) aggregations = None if self._msc3666_enabled: @@ -432,11 +364,16 @@ class SearchHandler: for context in contexts.values() ), # The returned events. - allowed_events, + search_result.allowed_events, ), user.to_string(), ) + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise, the 'age' will be wrong. + + time_now = self.clock.time_msec() + for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] @@ -445,44 +382,33 @@ class SearchHandler: context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] ) - state_results = {} - if include_state: - for room_id in {e.room_id for e in allowed_events}: - state = await self.state_handler.get_current_state(room_id) - state_results[room_id] = list(state.values()) + results = [ + { + "rank": search_result.rank_map[e.event_id], + "result": self._event_serializer.serialize_event( + e, time_now, bundle_aggregations=aggregations + ), + "context": contexts.get(e.event_id, {}), + } + for e in search_result.allowed_events + ] - # We're now about to serialize the events. We should not make any - # blocking calls after this. Otherwise the 'age' will be wrong - - results = [] - for e in allowed_events: - results.append( - { - "rank": rank_map[e.event_id], - "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations - ), - "context": contexts.get(e.event_id, {}), - } - ) - - rooms_cat_res = { + rooms_cat_res: JsonDict = { "results": results, - "count": count, - "highlights": list(highlights), + "count": search_result.count, + "highlights": list(search_result.highlights), } if state_results: - s = {} - for room_id, state_events in state_results.items(): - s[room_id] = self._event_serializer.serialize_events( - state_events, time_now - ) + rooms_cat_res["state"] = { + room_id: self._event_serializer.serialize_events(state_events, time_now) + for room_id, state_events in state_results.items() + } - rooms_cat_res["state"] = s - - if room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + if search_result.room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})[ + "room_id" + ] = search_result.room_groups if sender_group and "sender" in group_keys: rooms_cat_res.setdefault("groups", {})["sender"] = sender_group @@ -491,3 +417,282 @@ class SearchHandler: rooms_cat_res["next_batch"] = global_next_batch return {"search_categories": {"room_events": rooms_cat_res}} + + async def _search_by_rank( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + ) -> Tuple[_SearchResult, Dict[str, JsonDict]]: + """ + Performs a full text search for a user ordering by rank. + + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + + Returns: + A tuple of: + The search results. + A map of sender ID to results. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + # Holds result of grouping by sender, if applicable + sender_group: Dict[str, JsonDict] = {} + + search_result = await self.store.search_msgs(room_ids, search_term, keys) + + if search_result["highlights"]: + highlights = search_result["highlights"] + else: + highlights = set() + + results = search_result["results"] + + # event_id -> rank of event + rank_map = {r["event"].event_id: r["rank"] for r in results} + + filtered_events = await search_filter.filter([r["event"] for r in results]) + + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[: search_filter.limit] + + for e in allowed_events: + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) + rm["results"].append(e.event_id) + + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) + s["results"].append(e.event_id) + + return ( + _SearchResult( + search_result["count"], + rank_map, + allowed_events, + room_groups, + highlights, + ), + sender_group, + ) + + async def _search_by_recent( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + ) -> Tuple[_SearchResult, Optional[str]]: + """ + Performs a full text search for a user ordering by recent. + + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + + Returns: + A tuple of: + The search results. + Optionally, a pagination token. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + + # Holds the next_batch for the entire result set if one of those exists + global_next_batch = None + + highlights = set() + + room_events: List[EventBase] = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit and i < 5: + i += 1 + search_result = await self.store.search_rooms( + room_ids, + search_term, + keys, + search_filter.limit * 2, + pagination_token=pagination_token, + ) + + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + count = search_result["count"] + + results = search_result["results"] + + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = await search_filter.filter([r["event"] for r in results]) + + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[: search_filter.limit] + + if len(results) < search_filter.limit * 2: + break + else: + pagination_token = results[-1]["pagination_token"] + + for event in room_events: + group = room_groups.setdefault(event.room_id, {"results": []}) + group["results"].append(event.event_id) + + if room_events and len(room_events) >= search_filter.limit: + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] + + # We want to respect the given batch group and group keys so + # that if people blindly use the top level `next_batch` token + # it returns more from the same group (if applicable) rather + # than reverting to searching all results again. + if batch_group and batch_group_key: + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) + else: + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) + + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) + + return ( + _SearchResult(count, rank_map, room_events, room_groups, highlights), + global_next_batch, + ) + + async def _calculate_event_contexts( + self, + user: UserID, + allowed_events: List[EventBase], + before_limit: int, + after_limit: int, + include_profile: bool, + ) -> Dict[str, JsonDict]: + """ + Calculates the contextual events for any search results. + + Args: + user: The user performing the search. + allowed_events: The search results. + before_limit: + The number of events before a result to include as context. + after_limit: + The number of events after a result to include as context. + include_profile: True if historical profile information should be + included in the event context. + + Returns: + A map of event ID to contextual information. + """ + now_token = self.hs.get_event_sources().get_current_token() + + contexts = {} + for event in allowed_events: + res = await self.store.get_events_around( + event.room_id, event.event_id, before_limit, after_limit + ) + + logger.info( + "Context for search returned %d and %d events", + len(res.events_before), + len(res.events_after), + ) + + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before + ) + + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after + ) + + context: JsonDict = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace("room_key", res.end).to_string( + self.store + ), + } + + if include_profile: + senders = { + ev.sender + for ev in itertools.chain(events_before, [event], events_after) + } + + if events_after: + last_event_id = events_after[-1].event_id + else: + last_event_id = event.event_id + + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] + ) + + state = await self.state_store.get_state_for_event( + last_event_id, state_filter + ) + + context["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } + + contexts[event.event_id] = context + + return contexts diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 706ad7276..73861bbd4 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 0bb8b0929..ff5b5169c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -180,7 +180,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index d30ba2b72..2d197282e 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -30,7 +30,7 @@ class MatchChange(Enum): class StateDeltasHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _get_key_change( self, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 29e41a4c7..436cd971c 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -39,7 +39,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0934f9336..1c117852e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -266,7 +266,7 @@ class SyncResult: class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() @@ -697,6 +697,15 @@ class SyncHandler: else: # no events in this room - so presumably no state state = {} + + # (erikj) This should be rarely hit, but we've had some reports that + # we get more state down gappy syncs than we should, so let's add + # some logging. + logger.info( + "Failed to find any events in room %s at %s", + room_id, + stream_position.room_key, + ) return state async def compute_summary( @@ -1288,23 +1297,54 @@ class SyncHandler: # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id + users_that_have_changed = set() + + joined_rooms = sync_result_builder.joined_room_ids + + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + changed_users = self.store.get_cached_device_list_changes( + since_token.device_list_key ) + if changed_users is not None: + result = await self.store.get_rooms_for_users_with_stream_ordering( + changed_users + ) - # Always tell the user about their own devices. We check as the user - # ID is almost certainly already included (unless they're not in any - # rooms) and taking a copy of the set is relatively expensive. - if user_id not in users_who_share_room: - users_who_share_room = set(users_who_share_room) - users_who_share_room.add(user_id) + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + e.room_id in joined_rooms for e in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_who_share_room = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) - tracked_users = users_who_share_room + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) - # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, tracked_users - ) + tracked_users = users_who_share_room + users_that_have_changed = ( + await self.store.get_users_whose_devices_changed( + since_token.device_list_key, tracked_users + ) + ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: @@ -1328,7 +1368,14 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - newly_left_users -= users_who_share_room + left_users_rooms = ( + await self.store.get_rooms_for_users_with_stream_ordering( + newly_left_users + ) + ) + for user_id, entries in left_users_rooms.items(): + if any(e.room_id in joined_rooms for e in entries): + newly_left_users.discard(user_id) return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e4bed1c93..843c68eb0 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -57,7 +57,7 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self._main_store = hs.get_datastore() + self._main_store = hs.get_datastores().main self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 184730ebe..014754a63 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -139,7 +139,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: @@ -255,7 +255,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): super().__init__(hs) self.hs = hs self._enabled = bool(hs.config.registration.registration_requires_token) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def is_enabled(self) -> bool: return self._enabled diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1565e034c..d27ed2be6 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c5f8fcbb2..40bf1e06d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -351,7 +351,7 @@ class MatrixFederationHttpClient: ) self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 @@ -958,6 +958,7 @@ class MatrixFederationHttpClient: ) return body + @overload async def get_json( self, destination: str, @@ -967,7 +968,38 @@ class MatrixFederationHttpClient: timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + max_response_size: Optional[int] = ..., + ) -> T: + ... + + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + max_response_size: Optional[int] = None, + ): """GETs some json from the given host homeserver and path Args: @@ -992,6 +1024,13 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + max_response_size: The maximum size to read from the response. If None, + uses the default. + Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout + if parser is None: + parser = JsonParser() + body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, + max_response_size=max_response_size, ) return body diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py deleted file mode 100644 index b9933a152..000000000 --- a/synapse/logging/_structured.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os.path -from typing import Any, Dict, Generator, Optional, Tuple - -from constantly import NamedConstant, Names - -from synapse.config._base import ConfigError - - -class DrainType(Names): - CONSOLE = NamedConstant() - CONSOLE_JSON = NamedConstant() - CONSOLE_JSON_TERSE = NamedConstant() - FILE = NamedConstant() - FILE_JSON = NamedConstant() - NETWORK_JSON_TERSE = NamedConstant() - - -DEFAULT_LOGGERS = {"synapse": {"level": "info"}} - - -def parse_drain_configs( - drains: dict, -) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """ - Parse the drain configurations. - - Args: - drains (dict): A list of drain configurations. - - Yields: - dict instances representing a logging handler. - - Raises: - ConfigError: If any of the drain configuration items are invalid. - """ - - for name, config in drains.items(): - if "type" not in config: - raise ConfigError("Logging drains require a 'type' key.") - - try: - logging_type = DrainType.lookupByName(config["type"].upper()) - except ValueError: - raise ConfigError( - "%s is not a known logging drain type." % (config["type"],) - ) - - # Either use the default formatter or the tersejson one. - if logging_type in ( - DrainType.CONSOLE_JSON, - DrainType.FILE_JSON, - ): - formatter: Optional[str] = "json" - elif logging_type in ( - DrainType.CONSOLE_JSON_TERSE, - DrainType.NETWORK_JSON_TERSE, - ): - formatter = "tersejson" - else: - # A formatter of None implies using the default formatter. - formatter = None - - if logging_type in [ - DrainType.CONSOLE, - DrainType.CONSOLE_JSON, - DrainType.CONSOLE_JSON_TERSE, - ]: - location = config.get("location") - if location is None or location not in ["stdout", "stderr"]: - raise ConfigError( - ( - "The %s drain needs the 'location' key set to " - "either 'stdout' or 'stderr'." - ) - % (logging_type,) - ) - - yield name, { - "class": "logging.StreamHandler", - "formatter": formatter, - "stream": "ext://sys." + location, - } - - elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: - if "location" not in config: - raise ConfigError( - "The %s drain needs the 'location' key set." % (logging_type,) - ) - - location = config.get("location") - if os.path.abspath(location) != location: - raise ConfigError( - "File paths need to be absolute, '%s' is a relative path" - % (location,) - ) - - yield name, { - "class": "logging.FileHandler", - "formatter": formatter, - "filename": location, - } - - elif logging_type in [DrainType.NETWORK_JSON_TERSE]: - host = config.get("host") - port = config.get("port") - maximum_buffer = config.get("maximum_buffer", 1000) - - yield name, { - "class": "synapse.logging.RemoteHandler", - "formatter": formatter, - "host": host, - "port": port, - "maximum_buffer": maximum_buffer, - } - - else: - raise ConfigError( - "The %s drain type is currently not implemented." - % (config["type"].upper(),) - ) - - -def setup_structured_logging( - log_config: dict, -) -> dict: - """ - Convert a legacy structured logging configuration (from Synapse < v1.23.0) - to one compatible with the new standard library handlers. - """ - if "drains" not in log_config: - raise ConfigError("The logging configuration requires a list of drains.") - - new_config = { - "version": 1, - "formatters": { - "json": {"class": "synapse.logging.JsonFormatter"}, - "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, - }, - "handlers": {}, - "loggers": log_config.get("loggers", DEFAULT_LOGGERS), - "root": {"handlers": []}, - } - - for handler_name, handler in parse_drain_configs(log_config["drains"]): - new_config["handlers"][handler_name] = handler - - # Add each handler to the root logger. - new_config["root"]["handlers"].append(handler_name) - - return new_config diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d4fca3692..7e4693186 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -70,6 +70,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -144,6 +145,7 @@ __all__ = [ "JsonDict", "EventBase", "StateMap", + "ProfileInfo", ] logger = logging.getLogger(__name__) @@ -171,7 +173,9 @@ class ModuleApi: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() + self._store: Union[ + DataStore, "GenericWorkerSlavedStore" + ] = hs.get_datastores().main self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -317,6 +321,9 @@ class ModuleApi: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +335,7 @@ class ModuleApi: is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( @@ -648,7 +656,11 @@ class ModuleApi: Added in Synapse v1.9.0. Args: - auth_provider: identifier for the remote auth provider + auth_provider: identifier for the remote auth provider, see `sso` and + `oidc_providers` in the homeserver configuration. + + Note that no error is raised if the provided value is not in the + homeserver configuration. external_id: id on that system user_id: complete mxid that it is mapped to """ @@ -917,7 +929,7 @@ class ModuleApi: ) # Try to retrieve the resulting event. - event = await self._hs.get_datastore().get_event(event_id) + event = await self._hs.get_datastores().main.get_event(event_id) # update_membership is supposed to always return after the event has been # successfully persisted. @@ -1261,7 +1273,7 @@ class PublicRoomListManager: """ def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. diff --git a/synapse/notifier.py b/synapse/notifier.py index e0fad2da6..16d15a1f3 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -138,7 +138,7 @@ class _NotifierUserStream: self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms - noify_deferred = self.notify_deferred + notify_deferred = self.notify_deferred log_kv( { @@ -153,7 +153,7 @@ class _NotifierUserStream: with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - noify_deferred.callback(self.current_token) + notify_deferred.callback(self.current_token) def remove(self, notifier: "Notifier") -> None: """Remove this listener from all the indexes in the Notifier @@ -222,7 +222,7 @@ class Notifier: self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] # Called when there are new things to stream over replication diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5176a1c18..a1b771109 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -68,7 +68,7 @@ class ThrottleParams: class Pusher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() self.pusher_id = pusher_config.id diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 910b05c0d..832eaa34e 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -130,7 +130,9 @@ def make_base_prepend_rules( return rules -BASE_APPEND_CONTENT_RULES = [ +# We have to annotate these types, otherwise mypy infers them as +# `List[Dict[str, Sequence[Collection[str]]]]`. +BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/content/.m.rule.contains_user_name", "conditions": [ @@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [ ] -BASE_PREPEND_OVERRIDE_RULES = [ +BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.master", "enabled": False, @@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_OVERRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.suppress_notices", "conditions": [ @@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_UNDERRIDE_RULES = [ +BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/underride/.m.rule.call", "conditions": [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bee660893..fecf86034 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -103,7 +103,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -366,7 +366,7 @@ class RulesForRoom: """ self.room_id = room_id self.is_mine_id = hs.is_mine_id - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_push_rule_cache_metrics = room_push_rule_cache_metrics # Used to ensure only one thing mutates the cache at a time. Keyed off diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 39bb2acae..1710dd51b 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -66,7 +66,7 @@ class EmailPusher(Pusher): super().__init__(hs, pusher_config) self.mailer = mailer - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None self.throttle_params: Dict[str, ThrottleParams] = {} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 4dd3789c8..fda7c575d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -102,6 +102,7 @@ class HttpPusher(Pusher): self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] + self.badge_count_last_call: Optional[int] = None def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. @@ -125,11 +126,13 @@ class HttpPusher(Pusher): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) - await self._send_badge(badge) + if self.badge_count_last_call is None or self.badge_count_last_call != badge: + self.badge_count_last_call = badge + await self._send_badge(badge) def on_timer(self) -> None: self._start_processing() @@ -273,7 +276,7 @@ class HttpPusher(Pusher): tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( - self.hs.get_datastore(), + self.hs.get_datastores().main, self.user_id, group_by_room=self._group_unread_count_by_room, ) @@ -315,7 +318,7 @@ class HttpPusher(Pusher): # This was checked in the __init__, but mypy doesn't seem to know that. assert self.data is not None if self.data.get("format") == "event_id_only": - d = { + d: Dict[str, Any] = { "notification": { "event_id": event.event_id, "room_id": event.room_id, @@ -395,6 +398,8 @@ class HttpPusher(Pusher): rejected = [] if "rejected" in resp: rejected = resp["rejected"] + else: + self.badge_count_last_call = badge return rejected async def _send_badge(self, badge: int) -> None: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3df8452ee..649a4f49d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -112,7 +112,7 @@ class Mailer: self.template_text = template_text self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 659a53805..f617c759e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,12 +15,12 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern from synapse.events import EventBase -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -223,7 +223,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, JsonDict], + d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -234,7 +234,7 @@ def _flatten_dict( for key, value in d.items(): if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, dict): + elif isinstance(value, Mapping): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 7912311d2..d0cc657b4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -59,7 +59,7 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() # We shard the handling of push notifications by user ID. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 86162e0f2..8f48a3393 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -17,14 +17,7 @@ import itertools import logging -from typing import List, Set - -from pkg_resources import ( - DistributionNotFound, - Requirement, - VersionConflict, - get_provider, -) +from typing import Set logger = logging.getLogger(__name__) @@ -87,8 +80,11 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.1", + # ijson 3.1.4 fixes a bug with "." in property names + "ijson>=3.1.4", "matrix-common~=1.1.0", + # For runtime introspection of our dependencies + "packaging~=21.3", ] CONDITIONAL_REQUIREMENTS = { @@ -143,102 +139,6 @@ def list_requirements(): return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS) -class DependencyException(Exception): - @property - def message(self): - return "\n".join( - [ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ] - ) - - @property - def dependencies(self): - for i in self.args[0]: - yield '"' + i + '"' - - -def check_requirements(for_feature=None): - deps_needed = [] - errors = [] - - if for_feature: - reqs = CONDITIONAL_REQUIREMENTS[for_feature] - else: - reqs = REQUIREMENTS - - for dependency in reqs: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - deps_needed.append(dependency) - if for_feature: - errors.append( - "Needed %s for the '%s' feature but it was not installed" - % (dependency, for_feature) - ) - else: - errors.append("Needed %s but it was not installed" % (dependency,)) - - if not for_feature: - # Check the optional dependencies are up to date. We allow them to not be - # installed. - OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) - - for dependency in OPTS: - try: - _check_requirement(dependency) - except VersionConflict as e: - deps_needed.append(dependency) - errors.append( - "Needed optional %s, got %s==%s" - % ( - dependency, - e.dist.project_name, # type: ignore[attr-defined] # noqa - e.dist.version, # type: ignore[attr-defined] # noqa - ) - ) - except DistributionNotFound: - # If it's not found, we don't care - pass - - if deps_needed: - for err in errors: - logging.error(err) - - raise DependencyException(deps_needed) - - -def _check_requirement(dependency_string): - """Parses a dependency string, and checks if the specified requirement is installed - - Raises: - VersionConflict if the requirement is installed, but with the the wrong version - DistributionNotFound if nothing is found to provide the requirement - """ - req = Requirement.parse(dependency_string) - - # first check if the markers specify that this requirement needs installing - if req.marker is not None and not req.marker.evaluate(): - # not required for this environment - return - - get_provider(req) - - if __name__ == "__main__": import sys diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index bc1d28dd1..2e697c74a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -268,7 +268,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): raise e.to_synapse_error() except Exception as e: _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e + raise SynapseError( + 502, f"Failed to talk to {instance_name} process" + ) from e _outgoing_request_counter.labels(cls.NAME, 200).inc() return result diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f2f40129f..3d6364572 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -63,7 +63,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index d529c8a19..3e7300b4a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.registry = hs.get_federation_registry() @@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] @@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @staticmethod async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0145858e4..663bff573 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -50,7 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -119,7 +119,7 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): super().__init__(hs) self.federation_handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() @staticmethod @@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @@ -325,7 +325,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.distributor = hs.get_distributor() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index c7f751b70..6c8f8388f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod @@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 33e98daf8..ce7817683 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -69,7 +69,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d59ce7ccf..1b8479b0b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -111,7 +111,7 @@ class ReplicationDataHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._reactor = hs.get_reactor() self._clock = hs.get_clock() @@ -340,7 +340,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 17e157239..0d2013a3c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -95,7 +95,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ecd6190f5..494e42a2b 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -72,7 +72,7 @@ class ReplicationStreamer: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 914b9eae8..23d631a76 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -239,7 +239,7 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._current_token, @@ -267,7 +267,7 @@ class PresenceStream(Stream): ROW_TYPE = PresenceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main if hs.get_instance_name() in hs.config.worker.writers.presence: # on the presence writer, query the presence handler @@ -355,7 +355,7 @@ class ReceiptsStream(Stream): ROW_TYPE = ReceiptsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), @@ -374,7 +374,7 @@ class PushRulesStream(Stream): ROW_TYPE = PushRulesStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -401,7 +401,7 @@ class PushersStream(Stream): ROW_TYPE = PushersStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -434,7 +434,7 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), store.get_cache_stream_token_for_writer, @@ -455,7 +455,7 @@ class DeviceListsStream(Stream): ROW_TYPE = DeviceListsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), @@ -474,7 +474,7 @@ class ToDeviceStream(Stream): ROW_TYPE = ToDeviceStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), @@ -516,7 +516,7 @@ class AccountDataStream(Stream): ROW_TYPE = AccountDataStreamRow def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), @@ -585,7 +585,7 @@ class GroupServerStream(Stream): ROW_TYPE = GroupsStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), @@ -604,7 +604,7 @@ class UserSignatureStream(Stream): ROW_TYPE = UserSignatureStreamRow def __init__(self, hs: "HomeServer"): - store = hs.get_datastore() + store = hs.get_datastores().main super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 50c4a5ba0..26f4fa7cf 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -124,7 +124,7 @@ class EventsStream(Stream): NAME = "events" def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main super().__init__( hs.get_instance_name(), self._store._stream_id_gen.get_current_token_for_writer, diff --git a/synapse/res/providers.json b/synapse/res/providers.json index f1838f955..7b9958e45 100644 --- a/synapse/res/providers.json +++ b/synapse/res/providers.json @@ -5,8 +5,6 @@ "endpoints": [ { "schemes": [ - "https://twitter.com/*/status/*", - "https://*.twitter.com/*/status/*", "https://twitter.com/*/moments/*", "https://*.twitter.com/*/moments/*" ], @@ -14,4 +12,4 @@ } ] } -] \ No newline at end of file +] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba0d989d8..6de302f81 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index e9bce22a3..93a78db81 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d9905ff56..cef46ba0d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -44,7 +44,7 @@ class DeviceRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_GET( @@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine = hs.is_mine async def on_POST( diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 38477f8ea..6d634eef7 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index d162e0081..023ed9214 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) @@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, destination: str @@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._authenticator = Authenticator(hs) async def on_POST( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 299f5c9eb..8ca57bdb2 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet): ) def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet): PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_POST( @@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet): PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() @@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repository = hs.get_media_repository() async def on_GET( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 04948b640..af606e925 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token self.allowed_chars = string.ascii_letters + string.digits + "._~-" @@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Retrieve a registration token.""" diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b706efbc..f4736a3da 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -65,7 +65,7 @@ class RoomRestV2Servlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() async def on_DELETE( @@ -188,7 +188,7 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -278,7 +278,7 @@ class RoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -382,7 +382,7 @@ class RoomMembersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -408,7 +408,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -525,7 +525,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -670,7 +670,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_DELETE( self, request: SynapseRequest, room_identifier: str @@ -781,7 +781,7 @@ class BlockRoomRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 7a6546372..3b142b840 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c2617ee30..8e29ada8a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -156,7 +156,7 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() @@ -588,7 +588,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, target_user_id: str @@ -674,7 +674,7 @@ class ResetPasswordRestServlet(RestServlet): PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -717,7 +717,7 @@ class SearchUsersRestServlet(RestServlet): PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -775,7 +775,7 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine = hs.is_mine @@ -835,7 +835,7 @@ class UserMembershipRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str @@ -864,7 +864,7 @@ class PushersRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET( @@ -905,7 +905,7 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.is_mine_id = hs.is_mine_id @@ -974,7 +974,7 @@ class ShadowBanRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1026,7 +1026,7 @@ class RateLimitRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id @@ -1129,7 +1129,7 @@ class AccountDataRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id async def on_GET( diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index cfa2aee76..5587cae98 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -114,7 +114,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -168,7 +168,7 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @@ -347,7 +347,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.hs = hs self.config = hs.config self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -533,7 +533,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html @@ -600,7 +600,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): super().__init__() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: @@ -634,7 +634,7 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -768,7 +768,7 @@ class ThreepidUnbindRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() + self.datastore = self.hs.get_datastores().main async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are @@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet): response = { "user_id": requester.user.to_string(), # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + # Entered spec in Matrix 1.2 "org.matrix.msc3069.is_guest": bool(requester.is_guest), + "is_guest": bool(requester.is_guest), } # Appservices and similar accounts do not have device IDs @@ -894,6 +896,33 @@ class WhoamiRestServlet(RestServlet): return 200, response +class AccountStatusRestServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3720/account_status$", unstable=True, releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._account_handler = hs.get_account_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self._auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + if "user_ids" not in body: + raise SynapseError( + 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM + ) + + statuses, failures = await self._account_handler.get_account_statuses( + body["user_ids"], + allow_remote=True, + ) + + return 200, {"account_statuses": statuses, "failures": failures} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) @@ -908,3 +937,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3720_enabled: + AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 58b8adbd3..bfe985939 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( @@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() async def on_PUT( diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 9c15a0433..e0b2b80e5 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, ) @@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet): # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, @@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), error=e.msg, ) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 6682da077..4237071c6 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -72,22 +72,10 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - # Must be removed in later versions. - # Is only included for migration. - # Also the parts in `synapse/config/experimental.py`. - if self.config.experimental.msc3283_enabled: - response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.registration.enable_set_displayname + if self.config.experimental.msc3720_enabled: + response["capabilities"]["org.matrix.msc3720.account_status"] = { + "enabled": True, } - response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.registration.enable_set_avatar_url - } - response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.registration.enable_3pid_changes - } - - if self.config.experimental.msc3440_enabled: - response["capabilities"]["io.element.thread"] = {"enabled": True} return HTTPStatus.OK, response diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ee247e3d1..e181a0dde 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() @@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 672c82106..916f5230f 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -39,7 +39,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index a7e9aa3e9..7e1149c7f 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -705,7 +705,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id @_validate_group_id @@ -854,7 +854,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main @_validate_group_id async def on_PUT( @@ -879,7 +879,7 @@ class PublicisedGroupsForUserServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_GET( @@ -901,7 +901,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.groups_handler = hs.get_groups_local_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 49b1037b2..cfadcb8e5 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -33,7 +33,7 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 730c18f08..ce806e3c1 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -198,7 +198,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c..c9d44c596 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,13 +104,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), + store=hs.get_datastores().main, clock=hs.get_clock(), rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 8e427a96a..20377a9ac 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index add56d699..820682ec4 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8fe75bd75..a93f6fd5e 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 98604a938..d6487c31d 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -46,7 +46,9 @@ class PushersRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( + user.to_string() + ) filtered_pushers = [p.as_dict() for p in pushers] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c965e2bda..70baf50fa 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -123,7 +123,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): request, "email", email ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "email", email ) @@ -203,7 +203,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): request, "msisdn", msisdn ) - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( "msisdn", msisdn ) @@ -258,7 +258,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.auth = hs.get_auth() self.config = hs.config self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( @@ -385,7 +385,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -415,7 +415,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.identity_handler = hs.get_identity_handler() @@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet): session_id ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 2cab83c4e..487ea38b5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() @@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() async def on_GET( @@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index d4a4adb50..6e962a453 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -38,7 +38,7 @@ class ReportEventRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_POST( self, request: SynapseRequest, room_id: str, event_id: str diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index d31e58b1d..1d007fcbd 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -481,7 +481,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -557,7 +557,7 @@ class RoomMessageListRestServlet(RestServlet): self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -625,7 +625,7 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, room_id: str @@ -646,7 +646,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -1031,7 +1031,7 @@ class JoinedRoomsRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -1120,7 +1120,7 @@ class TimestampLookupRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() async def on_GET( @@ -1145,73 +1145,6 @@ class TimestampLookupRestServlet(RestServlet): } -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1305,7 +1238,6 @@ def register_servlets( RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomSpaceSummaryRestServlet(hs).register(http_server) RoomHierarchyRestServlet(hs).register(http_server) if hs.config.experimental.msc3266_enabled: RoomSummaryRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 4b6be3832..0048973e5 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 09a46737d..e669fa789 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_directory_active = hs.config.server.update_user_directory async def on_GET( diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f9615da52..f3018ff69 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c88cb9367..ca638755c 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -39,7 +39,7 @@ class TagListServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, user_id: str, room_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2290c57c1..2e5d0e4e2 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet): "r0.5.0", "r0.6.0", "r0.6.1", + "v1.1", + "v1.2", ], # as per MSC1497: "unstable_features": { @@ -97,6 +99,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 "org.matrix.msc3030": self.config.experimental.msc3030_enabled, + # Adds support for thread relations, per MSC3440. + "org.matrix.msc3440": self.config.experimental.msc3440_enabled, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 3d2afacc5..25f9ea285 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -78,7 +78,7 @@ class ConsentResource(DirectServeHtmlResource): super().__init__() self.hs = hs - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() # this is required by the request_handler wrapper diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 3923ba843..3525d6ae5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -94,7 +94,7 @@ class RemoteKey(DirectServeJsonResource): super().__init__() self.fetcher = ServerKeyFetcher(hs) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 71b9a34b1..6c414402b 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -75,7 +75,7 @@ class MediaRepository: self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0ff12a16a..d5576ea0c 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -133,13 +133,13 @@ class PreviewUrlResource(DirectServeJsonResource): self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.media.url_preview_ip_range_whitelist, ip_blacklist=hs.config.media.url_preview_ip_range_blacklist, - user_agent=f"{hs.version_string} UrlPreviewBot", + user_agent=f"Synapse (bot; +https://github.com/matrix-org/synapse)", use_proxy=True, ) self.media_repo = media_repo @@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource): url, output_stream=output_stream, max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, + headers={ + b"Accept-Language": self.url_preview_accept_language, + # Use a custom user agent for the preview because some sites will only return + # Open Graph metadata to crawler user agents. Omit the Synapse version + # string to avoid leaking information. + b"User-Agent": [ + "Synapse (bot; +https://github.com/matrix-org/synapse)" + ], + }, is_allowed_content_type=_is_previewable, ) except SynapseError: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ed91ef5a4..53b156524 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -50,7 +50,7 @@ class ThumbnailResource(DirectServeJsonResource): ): super().__init__() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index fde28d08c..e73e431dc 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -37,7 +37,7 @@ class UploadResource(DirectServeJsonResource): self.media_repo = media_repo self.filepaths = media_repo.filepaths - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 28a67f04e..6ac9dbc7c 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -44,7 +44,7 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): super().__init__() self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._local_threepid_handling_disabled_due_to_email_config = ( hs.config.email.local_threepid_handling_disabled_due_to_email_config diff --git a/synapse/server.py b/synapse/server.py index 564afdcb9..b5e2a319b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -17,7 +17,7 @@ # homeservers; either as a full homeserver as a real application, or a small # partial one for unit test mocking. -# Imports required for the default HomeServer() implementation + import abc import functools import logging @@ -62,6 +62,7 @@ from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler @@ -133,7 +134,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, DataStore, Storage +from synapse.storage import Databases, Storage from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -224,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta): # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be - # instantiated during setup() for future return by get_datastore() + # instantiated during setup() for future return by get_datastores() DATASTORE_CLASS = abc.abstractproperty() tls_server_context_factory: Optional[IOpenSSLContextFactory] @@ -354,12 +355,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_clock(self) -> Clock: return Clock(self._reactor) - def get_datastore(self) -> DataStore: - if not self.datastores: - raise Exception("HomeServer.setup must be called before getting datastores") - - return self.datastores.main - def get_datastores(self) -> Databases: if not self.datastores: raise Exception("HomeServer.setup must be called before getting datastores") @@ -373,7 +368,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( - store=self.get_datastore(), + store=self.get_datastores().main, clock=self.get_clock(), rate_hz=self.config.ratelimiting.rc_registration.per_second, burst_count=self.config.ratelimiting.rc_registration.burst_count, @@ -807,6 +802,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_external_cache(self) -> ExternalCache: return ExternalCache(self) + @cache_in_self + def get_account_handler(self) -> AccountHandler: + return AccountHandler(self) + @cache_in_self def get_outbound_redis_connection(self) -> "RedisProtocol": """ @@ -842,7 +841,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_request_ratelimiter(self) -> RequestRatelimiter: return RequestRatelimiter( - self.get_datastore(), + self.get_datastores().main, self.get_clock(), self.config.ratelimiting.rc_message, self.config.ratelimiting.rc_admin_redaction, diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e09a25591..698ca742e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -32,7 +32,7 @@ class ConsentServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._users_in_progress: Set[str] = set() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 8522930b5..015dd08f0 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,7 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 0cf60236f..7b4814e04 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -29,7 +29,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 67e8bc6ec..6babd5963 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -126,7 +126,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -258,7 +258,10 @@ class StateHandler: return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ class StateHandler: calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ class StateHandler: else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) + prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index cfe887b7f..ce3d1d4e9 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -24,6 +24,7 @@ from synapse.storage.prepare_database import prepare_database if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]): """ databases: List[DatabasePool] - main: DataStoreT + main: "DataStore" # FIXME: #11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 304814af5..069444655 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,14 +20,18 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,7 +60,7 @@ def _make_exclusive_regex( return exclusive_user_pattern -class ApplicationServiceWorkerStore(SQLBaseStore): +class ApplicationServiceWorkerStore(RoomMemberWorkerStore): def __init__( self, database: DatabasePool, @@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return service return None + @cached(iterable=True, cache_context=True) + async def get_app_service_users_in_room( + self, + room_id: str, + app_service: "ApplicationService", + cache_context: _CacheContext, + ) -> List[str]: + users_in_room = await self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + return list(filter(app_service.is_interested_in_user, users_in_room)) + class ApplicationServiceStore(ApplicationServiceWorkerStore): # This is currently empty due to there not being any AS storage functions @@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], + one_time_key_counts: TransactionOneTimeKeyCounts, + unused_fallback_keys: TransactionUnusedFallbackKeys, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. + one_time_key_counts: Counts of remaining one-time keys for relevant + appservice devices in the transaction. + unused_fallback_keys: Lists of unused fallback keys for relevant + appservice devices in the transaction. Returns: A new transaction. @@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, + one_time_key_counts=one_time_key_counts, + unused_fallback_keys=unused_fallback_keys, ) return await self.db_pool.runInteraction( @@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) + # TODO: to-device messages, one-time key counts and unused fallback keys + # are not yet populated for catch-up transactions. + # We likely want to populate those for reliability. return AppServiceTransaction( service=service, id=entry["txn_id"], events=events, ephemeral=[], to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8d845fe95..3b3a089b7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } + def get_cached_device_list_changes( + self, + from_key: int, + ) -> Optional[Set[str]]: + """Get set of users whose devices have changed since `from_key`, or None + if that information is not in our cache. + """ + + return self._device_list_stream_cache.get_all_entities_changed(from_key) + async def get_users_whose_devices_changed( self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1f8447b50..9b293475c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -29,6 +29,10 @@ import attr from canonicaljson import encode_canonical_json from synapse.api.constants import DeviceKeyAlgorithms +from synapse.appservice import ( + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def count_bulk_e2e_one_time_keys_for_as( + self, user_ids: Collection[str] + ) -> TransactionOneTimeKeyCounts: + """ + Counts, in bulk, the one-time keys for all the users specified. + Intended to be used by application services for populating OTK counts in + transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithm -> count + Empty algorithm -> count dicts are created if needed to represent a + lack of unused one-time keys. + """ + + def _count_bulk_e2e_one_time_keys_txn( + txn: LoggingTransaction, + ) -> TransactionOneTimeKeyCounts: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + sql = f""" + SELECT user_id, device_id, algorithm, COUNT(key_id) + FROM devices + LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id) + WHERE {user_in_where_clause} + GROUP BY user_id, device_id, algorithm + """ + txn.execute(sql, user_parameters) + + result: TransactionOneTimeKeyCounts = {} + + for user_id, device_id, algorithm, count in txn: + # We deliberately construct empty dictionaries for + # users and devices without any unused one-time keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_count_by_algo = result.setdefault(user_id, {}).setdefault( + device_id, {} + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_count_by_algo[algorithm] = count + + return result + + return await self.db_pool.runInteraction( + "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn + ) + + async def get_e2e_bulk_unused_fallback_key_types( + self, user_ids: Collection[str] + ) -> TransactionUnusedFallbackKeys: + """ + Finds, in bulk, the types of unused fallback keys for all the users specified. + Intended to be used by application services for populating unused fallback + keys in transactions. + + Return structure is of the shape: + user_id -> device_id -> algorithms + Empty lists are created for devices if there are no unused fallback + keys. This matches the response structure of MSC3202. + """ + if len(user_ids) == 0: + return {} + + def _get_bulk_e2e_unused_fallback_keys_txn( + txn: LoggingTransaction, + ) -> TransactionUnusedFallbackKeys: + user_in_where_clause, user_parameters = make_in_list_sql_clause( + self.database_engine, "devices.user_id", user_ids + ) + # We can't use USING here because we require the `.used` condition + # to be part of the JOIN condition so that we generate empty lists + # when all keys are used (as opposed to just when there are no keys at all). + sql = f""" + SELECT devices.user_id, devices.device_id, algorithm + FROM devices + LEFT JOIN e2e_fallback_keys_json AS fallback_keys + ON devices.user_id = fallback_keys.user_id + AND devices.device_id = fallback_keys.device_id + AND NOT fallback_keys.used + WHERE + {user_in_where_clause} + """ + txn.execute(sql, user_parameters) + + result: TransactionUnusedFallbackKeys = {} + + for user_id, device_id, algorithm in txn: + # We deliberately construct empty dictionaries and lists for + # users and devices without any unused fallback keys. + # We *could* omit these empty dicts if there have been no + # changes since the last transaction, but we currently don't + # do any change tracking! + device_unused_keys = result.setdefault(user_id, {}).setdefault( + device_id, [] + ) + if algorithm is not None: + # algorithm will be None if this device has no keys. + device_unused_keys.append(algorithm) + + return result + + return await self.db_pool.runInteraction( + "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn + ) + async def set_e2e_fallback_keys( self, user_id: str, device_id: str, fallback_keys: JsonDict ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5246fccad..ca2a9ba9d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -130,7 +130,7 @@ class PersistEventsStore: *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], - new_forward_extremeties: Dict[str, List[str]], + new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -143,7 +143,7 @@ class PersistEventsStore: the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state - new_forward_extremities: Map from room_id to list of event IDs + new_forward_extremities: Map from room_id to set of event IDs that are the new forward extremities of the room. use_negative_stream_ordering: Whether to start stream_ordering on the negative side and decrement. This should be set as True @@ -193,7 +193,7 @@ class PersistEventsStore: events_and_contexts=events_and_contexts, inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, ) persist_event_counter.inc(len(events_and_contexts)) @@ -220,7 +220,7 @@ class PersistEventsStore: for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremeties.items(): + for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -334,8 +334,8 @@ class PersistEventsStore: events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremeties: Optional[Dict[str, List[str]]] = None, - ): + new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + ) -> None: """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -353,13 +353,13 @@ class PersistEventsStore: from the database. This is useful when retrying due to IntegrityError. state_delta_for_room: The current-state delta for each room. - new_forward_extremetie: The new forward extremities for each room. + new_forward_extremities: The new forward extremities for each room. For each room, a list of the event ids which are the forward extremities. """ state_delta_for_room = state_delta_for_room or {} - new_forward_extremeties = new_forward_extremeties or {} + new_forward_extremities = new_forward_extremities or {} all_events_and_contexts = events_and_contexts @@ -372,7 +372,7 @@ class PersistEventsStore: self._update_forward_extremities_txn( txn, - new_forward_extremities=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, max_stream_order=max_stream_order, ) @@ -975,6 +975,17 @@ class PersistEventsStore: to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invalidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + if delta_state.no_longer_in_room: # Server is no longer in the room so we delete the room from # current_state_events, being careful we've already updated the @@ -993,6 +1004,11 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # We also want to invalidate the membership caches for users + # that were in the room. + users_in_room = self.store.get_users_in_room_txn(txn, room_id) + members_changed.update(users_in_room) + self.db_pool.simple_delete_txn( txn, table="current_state_events", @@ -1102,17 +1118,6 @@ class PersistEventsStore: # Invalidate the various caches - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invalidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } - for member in members_changed: txn.call_after( self.store.get_rooms_for_user_with_stream_ordering.invalidate, @@ -1153,7 +1158,10 @@ class PersistEventsStore: ) def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order + self, + txn: LoggingTransaction, + new_forward_extremities: Dict[str, Set[str]], + max_stream_order: int, ): for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( @@ -1468,10 +1476,10 @@ class PersistEventsStore: def _update_metadata_tables_txn( self, - txn, + txn: LoggingTransaction, *, - events_and_contexts, - all_events_and_contexts, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1948,20 +1956,20 @@ class PersistEventsStore: txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: + def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) @@ -2137,6 +2145,14 @@ class PersistEventsStore: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): + # double-check that we don't have any events that claim to be outliers + # *and* have partial state (which is meaningless: we should have no + # state at all for an outlier) + if context.partial_state: + raise ValueError( + "Outlier event %s claims to have partial state", event.event_id + ) + continue # if the event was rejected, just give it the same state as its @@ -2147,6 +2163,23 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group + # if we have partial state for these events, record the fact. (This happens + # here rather than in _store_event_txn because it also needs to happen when + # we de-outlier an event.) + self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8d4287045..26784f755 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore): include the previous states content in the unsigned field. allow_rejected: If True, return rejected events. Otherwise, - omits rejeted events from the response. + omits rejected events from the response. Returns: A mapping from event_id to event. @@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore): forward_edge_query = """ SELECT 1 FROM event_edges /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE event_edges.room_id = ? AND event_edges.prev_event_id = ? @@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore): "get_event_id_for_timestamp_txn", get_event_id_for_timestamp_txn, ) + + @cachedList("is_partial_state_event", list_name="event_ids") + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + """Checks which of the given events have partial state""" + result = await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ) + # convert the result to a dict, to make @cachedList work + partial = {r["event_id"] for r in result} + return {e_id: e_id in partial for e_id in event_ids} + + @cached() + async def is_partial_state_event(self, event_id: str) -> bool: + """Checks if the given event has partial state""" + result = await self.db_pool.simple_select_one_onecol( + table="partial_state_events", + keyvalues={"event_id": event_id}, + retcol="1", + allow_none=True, + desc="is_partial_state_event", + ) + return result is not None diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 8f09dd8e8..e9a0cdc6b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): for tp in self.hs.config.server.mau_limits_reserved_threepids[ : self.hs.config.server.max_mau_value ]: - user_id = await self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( tp["medium"], canonicalise_email(tp["address"]) ) if user_id: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4f05811a7..d3c461168 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) # Used by `PresenceStore._get_active_presence()` @@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._presence_id_gen: AbstractStreamIdGenerator + self._can_persist_presence = ( - hs.get_instance_name() in hs.config.worker.writers.presence + self._instance_name in hs.config.worker.writers.presence ) if isinstance(database.engine, PostgresEngine): @@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() - def _update_presence_txn(self, txn, stream_orderings, presence_states): + def _update_presence_txn( + self, txn: LoggingTransaction, stream_orderings, presence_states + ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id @@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): if last_id == current_id: return [], current_id, False - def get_all_presence_updates_txn(txn): + def get_all_presence_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active + status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] + updates = cast( + List[Tuple[int, list]], + [(row[0], row[1:]) for row in txn], + ) upper_bound = current_id limited = False @@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): ) @cached() - def _get_presence_for_user(self, user_id): + def _get_presence_for_user(self, user_id: str) -> None: raise NotImplementedError() @cachedList( @@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): list_name="user_ids", num_args=1, ) - async def get_presence_for_users(self, user_ids): + async def get_presence_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn(txn): + def _should_user_receive_full_presence_with_token_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? @@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): _should_user_receive_full_presence_with_token_txn, ) - async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: """Adds to the list of users who should receive a full snapshot of presence upon their next sync. @@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return users_to_state - def get_current_presence_token(self): + def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() - def _get_active_presence(self, db_conn: Connection): + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ @@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): return [UserPresenceState(**row) for row in rows] - def take_presence_startup_info(self): + def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup - self._presence_on_startup = None + self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index e87a8fb85..2e3818e43 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from typing import Any, List, Set, Tuple +from typing import Any, List, Set, Tuple, cast from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken @@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_history_txn( - self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool + self, + txn: LoggingTransaction, + room_id: str, + token: RoomStreamToken, + delete_local_events: bool, ) -> Set[int]: # Tables that should be pruned: # event_auth @@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """, (room_id,), ) - (min_depth,) = txn.fetchone() + (min_depth,) = cast(Tuple[int], txn.fetchone()) logger.info("[purge] updating room_depth to %d", min_depth) @@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "purge_room", self._purge_room_txn, room_id ) - def _purge_room_txn(self, txn, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: # First we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index aac94fa46..dc6665237 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> None: """Record a mapping from an external user id to a mxid + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): external_id: str, user_id: str, ) -> None: + """ + Record a mapping from an external user id to a mxid. + + Note that the auth provider IDs (and the external IDs) are not validated + against configured IdPs as Synapse does not know its relationship to + external systems. For example, it might be useful to pre-configure users + before enabling a new IdP or an IdP might be temporarily offline, but + still valid. + + Args: + txn: The database transaction. + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ self.db_pool.simple_insert_txn( txn, @@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Replace mappings from external user ids to a mxid in a single transaction. All mappings are deleted and the new ones are created. + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: record_external_ids: List with tuple of auth_provider and external_id to record user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -1660,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): user_id=row[1], device_id=row[2], next_token_id=row[3], - has_next_refresh_token_been_refreshed=row[4], + # SQLite returns 0 or 1 for false/true, so convert to a bool. + has_next_refresh_token_been_refreshed=bool(row[4]), # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), expiry_ts=row[6], @@ -1676,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Set the successor of a refresh token, removing the existing successor if any. + This also deletes the predecessor refresh and access tokens, + since they cannot be valid anymore. + Args: token_id: ID of the refresh token to update. next_token_id: ID of its successor. """ - def _replace_refresh_token_txn(txn) -> None: + def _replace_refresh_token_txn(txn: LoggingTransaction) -> None: # First check if there was an existing refresh token old_next_token_id = self.db_pool.simple_select_one_onecol_txn( txn, @@ -1707,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): {"id": old_next_token_id}, ) + # Delete the previous refresh token, since we only want to keep the + # last 2 refresh tokens in the database. + # (The predecessor of the latest refresh token is still useful in + # case the refresh was interrupted and the client re-uses the old + # one.) + # This cascades to delete the associated access token. + self.db_pool.simple_delete_txn( + txn, "refresh_tokens", {"next_token_id": token_id} + ) + await self.db_pool.runInteraction( "replace_refresh_token", _replace_refresh_token_txn ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e2c27e594..36aa1092f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -53,8 +53,13 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: + # The latest event in the thread. latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. count: int + # True if the current user has sent an event to the thread. current_user_participated: bool @@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def _get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: - """Get the number of threaded replies and the latest reply (if any) for the given event. + ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: + """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. Args: event_ids: Summarize the thread related to this event ID. @@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore): A map of the thread summary each event. A missing event implies there are no threaded replies. - Each summary includes the number of items in the thread and the most - recent response. + Each summary is a tuple of: + The number of events in the thread. + The most recent event in the thread. + The most recent edit to the most recent event in the thread, if applicable. """ def _get_thread_summaries_txn( @@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore): # TODO Should this only allow m.room.message events. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, - # so ordering by topologica ordering + stream ordering desc will + # so ordering by topological ordering + stream ordering desc will # ensure we get the latest event in the thread. sql = """ SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child @@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] + # Check to see if any of those events are edited. + latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + # Map to the event IDs to the thread summary. # # There might not be a summary due to there not being a thread or @@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore): summary = None if latest_event: - summary = (counts[parent_event_id], latest_event) + latest_edit = latest_edits.get(latest_event_id) + summary = (counts[parent_event_id], latest_event, latest_edit) summaries[parent_event_id] = summary return summaries @@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore): ) for event_id, summary in summaries.items(): if summary: - thread_count, latest_thread_event = summary + thread_count, latest_thread_event, edit = summary results.setdefault( event_id, BundledAggregations() ).thread = _ThreadAggregation( latest_event=latest_thread_event, + latest_edit=edit, count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 95167116c..94068940b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Collection, Dict, List, Optional, @@ -1498,7 +1499,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( - self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: """Ensure that the room is stored in the table @@ -1511,7 +1512,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): has_auth_chain_index = await self.has_auth_chain_index(room_id) create_event = None - for e in auth_events: + for e in state_events: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break @@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): lock=False, ) + async def store_partial_state_room( + self, + room_id: str, + servers: Collection[str], + ) -> None: + """Mark the given room as containing events with partial state + + Args: + room_id: the ID of the room + servers: other servers known to be in the room + """ + await self.db_pool.runInteraction( + "store_partial_state_room", + self._store_partial_state_room_txn, + room_id, + servers, + ) + + @staticmethod + def _store_partial_state_room_txn( + txn: LoggingTransaction, room_id: str, servers: Collection[str] + ) -> None: + DatabasePool.simple_insert_txn( + txn, + table="partial_state_rooms", + values={ + "room_id": room_id, + }, + ) + DatabasePool.simple_insert_many_txn( + txn, + table="partial_state_rooms_servers", + keys=("room_id", "server_name"), + values=((room_id, s) for s in servers), + ) + async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ) -> None: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4489732fd..e48ec5f49 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) + @cachedList( + cached_method_name="get_rooms_for_user_with_stream_ordering", + list_name="user_ids", + ) + async def get_rooms_for_users_with_stream_ordering( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + """A batched version of `get_rooms_for_user_with_stream_ordering`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + return await self.db_pool.runInteraction( + "get_rooms_for_users_with_stream_ordering", + self._get_rooms_for_users_with_stream_ordering_txn, + user_ids, + ) + + def _get_rooms_for_users_with_stream_ordering_txn( + self, txn, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + + clause, args = make_in_list_sql_clause( + self.database_engine, + "c.state_key", + user_ids, + ) + + if self._current_state_events_membership_up_to_date: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.membership = ? + AND {clause} + """ + else: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND m.membership = ? + AND {clause} + """ + + txn.execute(sql, [Membership.JOIN] + args) + + result = {user_id: set() for user_id in user_ids} + for user_id, room_id, instance, stream_id in txn: + result[user_id].add( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + ) + + return {user_id: frozenset(v) for user_id, v in result.items()} + async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2d085a576..e23b11907 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -28,6 +28,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -114,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings" def __init__( self, @@ -146,6 +148,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) + self.db_pool.updates.register_background_update_handler( + self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings + ) + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -371,6 +377,27 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return num_rows + async def _background_delete_non_strings( + self, progress: JsonDict, batch_size: int + ) -> int: + """Deletes rows with non-string `value`s from `event_search` if using sqlite. + + Prior to Synapse 1.44.0, malformed events received over federation could cause integers + to be inserted into the `event_search` table when using sqlite. + """ + + def delete_non_strings_txn(txn: LoggingTransaction) -> None: + txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'") + + await self.db_pool.runInteraction( + self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn + ) + + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_DELETE_NON_STRINGS + ) + return 1 + class SearchStore(SearchBackgroundUpdateStore): def __init__( @@ -381,17 +408,19 @@ class SearchStore(SearchBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) - async def search_msgs(self, room_ids, search_term, keys): + async def search_msgs( + self, room_ids: Collection[str], search_term: str, keys: Iterable[str] + ) -> JsonDict: """Performs a full text search over events with given keys. Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", "content.name", "content.topic" Returns: - list of dicts + Dictionary of results """ clauses = [] @@ -499,10 +528,10 @@ class SearchStore(SearchBackgroundUpdateStore): self, room_ids: Collection[str], search_term: str, - keys: List[str], + keys: Iterable[str], limit, pagination_token: Optional[str] = None, - ) -> List[dict]: + ) -> JsonDict: """Performs a full text search over events with given keys. Args: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 2fb3e6519..417aef1db 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -42,6 +42,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + return v + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" @@ -62,11 +72,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Typically this happens if support for the room's version has been removed from Synapse. """ - return await self.db_pool.runInteraction( - "get_room_version_txn", - self.get_room_version_txn, - room_id, - ) + room_version_id = await self.get_room_version_id(room_id) + return _retrieve_and_check_room_version(room_id, room_version_id) def get_room_version_txn( self, txn: LoggingTransaction, room_id: str @@ -82,15 +89,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): removed from Synapse. """ room_version_id = self.get_room_version_id_txn(txn, room_id) - v = KNOWN_ROOM_VERSIONS.get(room_version_id) - - if not v: - raise UnsupportedRoomVersionError( - "Room %s uses a room version %s which is no longer supported" - % (room_id, room_version_id) - ) - - return v + return _retrieve_and_check_room_version(room_id, room_version_id) @cached(max_entries=10000) async def get_room_version_id(self, room_id: str) -> str: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f7c778bdf..e7fddd242 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = await self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] if is_in_room: - users_with_profile = await self.get_users_in_room_with_profiles(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] # Throw away users excluded from the directory. users_with_profile = { user_id: profile @@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): for user_id in users_to_work_on: if await self.should_include_local_user_in_dir(user_id): - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # technically it could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice sender can be # contacted. - if self.get_app_service_by_user_id(user) is not None: + if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] return False # We're opting to exclude appservice users (anyone matching the user @@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # they could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice users can be # contacted. - if self.get_if_app_services_interested_in_user(user): + if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] # TODO we might want to make this configurable for each app service return False # Support users are for diagnostics and should not appear in the user directory. - if await self.is_support_user(user): + if await self.is_support_user(user): # type: ignore[attr-defined] return False # Deactivated users aren't contactable, so should not appear in the user directory. try: - if await self.get_user_deactivated_status(user): + if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] return False except StoreError: # No such user in the users table. No need to do this when calling @@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = await self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] if hist_vis_ev: if ( hist_vis_ev.content.get("history_visibility") diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 428d66a61..7d543fdbe 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -427,21 +427,21 @@ class EventsPersistenceStorage: # NB: Assumes that we are only persisting events for one room # at a time. - # map room_id->list[event_ids] giving the new forward + # map room_id->set[event_ids] giving the new forward # extremities in each room - new_forward_extremeties = {} + new_forward_extremities: Dict[str, Set[str]] = {} # map room_id->(type,state_key)->event_id tracking the full # state in each room after adding these events. # This is simply used to prefill the get_current_state_ids # cache - current_state_for_room = {} + current_state_for_room: Dict[str, StateMap[str]] = {} # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each # room - state_delta_for_room = {} + state_delta_for_room: Dict[str, DeltaState] = {} # Set of remote users which were in rooms the server has left. We # should check if we still share any rooms and if not we mark their @@ -460,14 +460,13 @@ class EventsPersistenceStorage: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = ( + latest_event_ids = set( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) - latest_event_ids = set(latest_event_ids) if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -478,7 +477,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids len_1 = ( len(latest_event_ids) == 1 @@ -533,7 +532,7 @@ class EventsPersistenceStorage: # extremities, so we'll `continue` above and skip this bit.) assert new_latest_event_ids, "No forward extremities left!" - new_forward_extremeties[room_id] = new_latest_event_ids + new_forward_extremities[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -567,7 +566,7 @@ class EventsPersistenceStorage: ) if not is_still_joined: logger.info("Server no longer in room %s", room_id) - latest_event_ids = [] + latest_event_ids = set() current_state = {} delta.no_longer_in_room = True @@ -582,7 +581,7 @@ class EventsPersistenceStorage: chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, ) @@ -596,7 +595,7 @@ class EventsPersistenceStorage: room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: Collection[str], - ): + ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -906,9 +905,9 @@ class EventsPersistenceStorage: # Ideally we'd figure out a way of still being able to drop old # dummy events that reference local events, but this is good enough # as a first cut. - events_to_check = [event] + events_to_check: Collection[EventBase] = [event] while events_to_check: - new_events = set() + new_events: Set[str] = set() for event_to_check in events_to_check: if self.is_mine_id(event_to_check.sender): if event_to_check.type != EventTypes.Dummy: diff --git a/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql new file mode 100644 index 000000000..09305638e --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql @@ -0,0 +1,28 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- next_token_id is a foreign key reference, so previously required a table scan +-- when a row in the referenced table was deleted. +-- As it was self-referential and cascaded deletes, this led to O(t*n) time to +-- delete a row, where t: number of rows in the table and n: number of rows in +-- the ancestral 'chain' of access tokens. +-- +-- This index is partial since we only require it for rows which reference +-- another. +-- Performance was tested to be the same regardless of whether the index was +-- full or partial, but a partial index can be smaller. +CREATE INDEX refresh_tokens_next_token_id + ON refresh_tokens(next_token_id) + WHERE next_token_id IS NOT NULL; diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql new file mode 100644 index 000000000..815c0cc39 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql @@ -0,0 +1,41 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- rooms which we have done a partial-state-style join to +CREATE TABLE IF NOT EXISTS partial_state_rooms ( + room_id TEXT PRIMARY KEY, + FOREIGN KEY(room_id) REFERENCES rooms(room_id) +); + +-- a list of remote servers we believe are in the room +CREATE TABLE IF NOT EXISTS partial_state_rooms_servers ( + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + server_name TEXT NOT NULL, + UNIQUE(room_id, server_name) +); + +-- a list of events with partial state. We can't store this in the `events` table +-- itself, because `events` is meant to be append-only. +CREATE TABLE IF NOT EXISTS partial_state_events ( + -- the room_id is denormalised for efficient indexing (the canonical source is `events`) + room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id), + event_id TEXT NOT NULL REFERENCES events(event_id), + UNIQUE(event_id) +); + +CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx + ON partial_state_events (room_id); + + diff --git a/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite new file mode 100644 index 000000000..140df6526 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Delete rows with non-string `value`s from `event_search` if using sqlite. +-- +-- Prior to Synapse 1.44.0, malformed events received over federation could +-- cause integers to be inserted into the `event_search` table. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6805, 'event_search_sqlite_delete_non_strings', '{}'); diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py new file mode 100644 index 000000000..a2ec4fc26 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -0,0 +1,72 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds triggers to the partial_state_events tables to enforce uniqueness + +Triggers cannot be expressed in .sql files, so we have to use a separate file. +""" +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # complain if the room_id in partial_state_events doesn't match + # that in `events`. We already have a fk constraint which ensures that the event + # exists in `events`, so all we have to do is raise if there is a row with a + # matching stream_ordering but not a matching room_id. + if isinstance(database_engine, Sqlite3Engine): + cur.execute( + """ + CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id + BEFORE INSERT ON partial_state_events + FOR EACH ROW + BEGIN + SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events') + WHERE EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ); + END; + """ + ) + elif isinstance(database_engine, PostgresEngine): + cur.execute( + """ + CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$ + BEGIN + IF EXISTS ( + SELECT 1 FROM events + WHERE events.event_id = NEW.event_id + AND events.room_id != NEW.room_id + ) THEN + RAISE EXCEPTION 'Incorrect room_id in partial_state_events'; + END IF; + RETURN NEW; + END; + $BODY$ LANGUAGE plpgsql; + """ + ) + + cur.execute( + """ + CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events + FOR EACH ROW + EXECUTE PROCEDURE check_partial_state_events() + """ + ) + else: + raise NotImplementedError("Unknown database engine") diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 913448f0f..e79ecf64a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -204,13 +204,16 @@ class StateFilter: if get_all_members: # We want to return everything. return StateFilter.all() - else: + elif EventTypes.Member in self.types: # We want to return all non-members, but only particular # memberships return StateFilter( types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. @@ -528,6 +531,9 @@ class StateFilter: _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 21591d0bf..fb8fe1729 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -37,16 +37,18 @@ class _EventSourcesInner: account_data: AccountDataEventSource def get_sources(self) -> Iterator[Tuple[str, EventSource]]: - for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined] + for attribute in attr.fields(_EventSourcesInner): yield attribute.name, getattr(self, attribute.name) class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined] + # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since + # all the attributes of `_EventSourcesInner` are annotated. + *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def get_current_token(self) -> StreamToken: push_rules_key = self.store.get_max_push_rules_stream_id() diff --git a/synapse/types.py b/synapse/types.py index f89fb216a..53be3583a 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore + from synapse.storage.databases.main import DataStore, PurgeEventsStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -485,7 +485,7 @@ class RoomStreamToken: ) @classmethod - async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -502,7 +502,7 @@ class RoomStreamToken: instance_id = int(key) pos = int(value) - instance_name = await store.get_name_from_instance_id(instance_id) + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] instance_map[instance_name] = pos return cls( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 511f52534..58b4220ff 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure: Failure) -> Failure: - # defer.gatherResults and DeferredLists wrap failures. + # Deprecated: you probably just want to catch defer.FirstError and reraise + # the subFailure's value, which will do a better job of preserving stacktraces. + # (actually, you probably want to use yieldable_gather_results anyway) failure.trap(defer.FirstError) return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 3f7299aff..60c03a66f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -29,6 +29,7 @@ from typing import ( Hashable, Iterable, Iterator, + List, Optional, Set, Tuple, @@ -51,7 +52,7 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.util import Clock, unwrapFirstError +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): T = TypeVar("T") -def concurrently_execute( +async def concurrently_execute( func: Callable[[T], Any], args: Iterable[T], limit: int -) -> defer.Deferred: +) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -221,20 +222,14 @@ def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - return make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(_concurrently_execute_inner, value) - for value in itertools.islice(it, limit) - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) + await yieldable_gather_results( + _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) + ) -def yieldable_gather_results( - func: Callable, iter: Iterable, *args: Any, **kwargs: Any -) -> defer.Deferred: +async def yieldable_gather_results( + func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any +) -> List[T]: """Executes the function with each argument concurrently. Args: @@ -245,15 +240,30 @@ def yieldable_gather_results( **kwargs: Keyword arguments to be passed to each call to func Returns - Deferred[list]: Resolved when all functions have been invoked, or errors if - one of the function calls fails. + A list containing the results of the function """ - return make_deferred_yieldable( - defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], - consumeErrors=True, + try: + return await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) ) - ).addErrback(unwrapFirstError) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None T1 = TypeVar("T1") @@ -545,7 +555,10 @@ class ReadWriteLock: finally: with PreserveLoggingContext(): new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: + # `self.key_to_current_writer[key]` may be missing if there was another + # writer waiting for us and it completed entirely within the + # `new_defer.callback()` call above. + if self.key_to_current_writer.get(key) == new_defer: self.key_to_current_writer.pop(key) return _ctx_manager() @@ -655,3 +668,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: return value return DoneAwaitable(value) + + +def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`, + but will not propagate cancellation through to the original. When cancelled, + the new `Deferred` will fail with a `CancelledError` and will not follow the + Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap + the new `Deferred`. + """ + new_deferred: defer.Deferred[T] = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 15debd6c4..1cbc180ed 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n class EvictionReason(Enum): size = auto() time = auto() + invalidation = auto() @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index df4fb156c..1cdead02f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import inspect import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes - the tuple of missing keys to the wrapped function. + the set of missing keys to the wrapped function. - Once wrapped, the function returns a Deferred which resolves to the list - of results. + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from + input key to output value. """ def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - cache.set(key, deferred, callback=invalidate_callback) + cached_defers.append( + cache.set(key, deferred, callback=invalidate_callback) + ) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a - # a dict. We can now resolve the observable deferreds in - # the cache and update our own result map. - for e in missing: + # the wrapped function has completed. It returns a dict. + # We can now update our own result map, and then resolve the + # observable deferreds in the cache. + for e, d1 in deferreds_map.items(): val = res.get(e, None) - deferreds_map[e].callback(val) + # make sure we update the results map before running the + # deferreds, because as soon as we run the last deferred, the + # gatherResults() below will complete and return the result + # dict to our caller. results[e] = val + d1.callback(val) - def errback(f: Failure) -> Failure: - # the wrapped function has failed. Invalidate any cache - # entries we're supposed to be populating, and fail - # their deferreds. - for e in missing: - key = arg_to_cache_key(e) - cache.invalidate(key) - deferreds_map[e].errback(f) - - # return the failure, to propagate to our caller. - return f + def errback_all(f: Failure) -> None: + # the wrapped function has failed. Propagate the failure into + # the cache, which will invalidate the entry, and cause the + # relevant cached_deferreds to fail, which will propagate the + # failure to our caller. + for d1 in deferreds_map.values(): + d1.errback(f) args_to_call = dict(arg_dict) - # copy the missing set before sending it to the callee, to guard against - # modification. - args_to_call[self.list_name] = tuple(missing) + args_to_call[self.list_name] = missing - cached_defers.append( - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback) - ) + # dispatch the call, and attach the two handlers + defer.maybeDeferred( + preserve_fn(self.orig), **args_to_call + ).addCallbacks(complete_all, errback_all) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 67ee4c693..c6a5d0dfc 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]): raise KeyError(key) return default + if self.iterable: + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + else: + self.metrics.inc_evictions(EvictionReason.invalidation) + return value.value def __contains__(self, key: KT) -> bool: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 7548b3854..45ff0de63 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -560,8 +560,10 @@ class LruCache(Generic[KT, VT]): def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: - delete_node(node) + evicted_len = delete_node(node) cache.pop(node.key, None) + if metrics: + metrics.inc_evictions(EvictionReason.invalidation, evicted_len) return node.value else: return default diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py new file mode 100644 index 000000000..3a1f6b3c7 --- /dev/null +++ b/synapse/util/check_dependencies.py @@ -0,0 +1,127 @@ +import logging +from typing import Iterable, NamedTuple, Optional + +from packaging.requirements import Requirement + +DISTRIBUTION_NAME = "matrix-synapse" + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + + +class DependencyException(Exception): + @property + def message(self) -> str: + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) + + @property + def dependencies(self) -> Iterable[str]: + for i in self.args[0]: + yield '"' + i + '"' + + +EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) + + +class Dependency(NamedTuple): + requirement: Requirement + must_be_installed: bool + + +def _generic_dependencies() -> Iterable[Dependency]: + """Yield pairs (requirement, must_be_installed).""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # https://packaging.pypa.io/en/latest/markers.html#usage notes that + # > Evaluating an extra marker with no environment is an error + # so we pass in a dummy empty extra value here. + must_be_installed = req.marker is None or req.marker.evaluate({"extra": ""}) + yield Dependency(req, must_be_installed) + + +def _dependencies_for_extra(extra: str) -> Iterable[Dependency]: + """Yield additional dependencies needed for a given `extra`.""" + requirements = metadata.requires(DISTRIBUTION_NAME) + assert requirements is not None + for raw_requirement in requirements: + req = Requirement(raw_requirement) + # Exclude mandatory deps by only selecting deps needed with this extra. + if ( + req.marker is not None + and req.marker.evaluate({"extra": extra}) + and not req.marker.evaluate({"extra": ""}) + ): + yield Dependency(req, True) + + +def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: + if extra: + return f"Need {requirement.name} for {extra}, but it is not installed" + else: + return f"Need {requirement.name}, but it is not installed" + + +def _incorrect_version( + requirement: Requirement, got: str, extra: Optional[str] = None +) -> str: + if extra: + return f"Need {requirement} for {extra}, but got {requirement.name}=={got}" + else: + return f"Need {requirement}, but got {requirement.name}=={got}" + + +def check_requirements(extra: Optional[str] = None) -> None: + """Check Synapse's dependencies are present and correctly versioned. + + If provided, `extra` must be the name of an pacakging extra (e.g. "saml2" in + `pip install matrix-synapse[saml2]`). + + If `extra` is None, this function checks that + - all mandatory dependencies are installed and correctly versioned, and + - each optional dependency that's installed is correctly versioned. + + If `extra` is not None, this function checks that + - the dependencies needed for that extra are installed and correctly versioned. + + :raises DependencyException: if a dependency is missing or incorrectly versioned. + :raises ValueError: if this extra does not exist. + """ + # First work out which dependencies are required, and which are optional. + if extra is None: + dependencies = _generic_dependencies() + elif extra in EXTRAS: + dependencies = _dependencies_for_extra(extra) + else: + raise ValueError(f"Synapse does not provide the feature '{extra}'") + + deps_unfulfilled = [] + errors = [] + + for (requirement, must_be_installed) in dependencies: + try: + dist: metadata.Distribution = metadata.distribution(requirement.name) + except metadata.PackageNotFoundError: + if must_be_installed: + deps_unfulfilled.append(requirement.name) + errors.append(_not_installed(requirement, extra)) + else: + if not requirement.specifier.contains(dist.version): + deps_unfulfilled.append(requirement.name) + errors.append(_incorrect_version(requirement, dist.version, extra)) + + if deps_unfulfilled: + for err in errors: + logging.error(err) + + raise DependencyException(deps_unfulfilled) diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index de04f34e4..031880ec3 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -20,7 +20,7 @@ import os import signal import sys from types import FrameType, TracebackType -from typing import NoReturn, Type +from typing import NoReturn, Optional, Type def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: @@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # also catch any other uncaught exceptions before we get that far.) def excepthook( - type_: Type[BaseException], value: BaseException, traceback: TracebackType + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) @@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: + def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon." % signum) sys.exit(0) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 1f18654d4..6d4b0b7c5 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, Generator, List, TypeVar +from typing import Any, Callable, Generator, List, TypeVar, cast from twisted.internet import defer from twisted.internet.defer import Deferred @@ -174,7 +174,9 @@ def _check_yield_points( ) ) changes.append(err) - return getattr(e, "value", None) + # The `StopIteration` or `_DefGen_Return` contains the return value from the + # generator. + return cast(T, e.value) frame = gen.gi_frame diff --git a/synctl b/synctl index 0e54f4847..1ab36949c 100755 --- a/synctl +++ b/synctl @@ -37,6 +37,13 @@ YELLOW = "\x1b[1;33m" RED = "\x1b[1;31m" NORMAL = "\x1b[m" +SYNCTL_CACHE_FACTOR_WARNING = """\ +Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do +one of the following: + - Either set the environment variable 'SYNAPSE_CACHE_FACTOR' + - or set 'caches.global_factor' in the homeserver config. +--------------------------------------------------------------------------------""" + def pid_running(pid): try: @@ -228,6 +235,7 @@ def main(): start_stop_synapse = True if cache_factor: + write(SYNCTL_CACHE_FACTOR_WARNING) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4b53b6d40..3e0578992 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -26,8 +28,10 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.appservice import ApplicationService +from synapse.server import HomeServer from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock @@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders class AuthTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.store = Mock() - hs.get_datastore = Mock(return_value=self.store) + hs.datastores.main = self.store hs.get_auth_handler().store = self.store self.auth = Auth(hs) @@ -67,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): self.store.get_user_by_access_token = simple_async_mock(None) @@ -105,7 +109,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -124,7 +128,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(requester.user.to_string(), self.test_user) + self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): from netaddr import IPSet @@ -191,7 +195,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -238,10 +242,10 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals( + self.assertEqual( requester.user.to_string(), masquerading_user_id.decode("utf8") ) - self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8")) @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): @@ -271,8 +275,8 @@ class AuthTestCase(unittest.HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) - self.assertEquals(failure.value.code, 400) - self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + self.assertEqual(failure.value.code, 400) + self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): self.store.get_user_by_access_token = simple_async_mock( @@ -305,7 +309,7 @@ class AuthTestCase(unittest.HomeserverTestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) - self.assertEquals(self.store.insert_client_ip.call_count, 2) + self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( @@ -365,9 +369,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) @@ -469,9 +473,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no @@ -484,9 +488,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.value.code, 403) + self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.code, 403) def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index b7fc33dc9..8c3354ce3 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,6 +18,7 @@ from unittest.mock import patch import jsonschema +from frozendict import frozendict from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError @@ -40,7 +41,7 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main def test_errors_on_invalid_filters(self): invalid_filters = [ @@ -327,6 +328,15 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertFalse(Filter(self.hs, definition)._check(event)) + # check it works with frozendicts too + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content=frozendict({EventContentFields.LABELS: ["#fun"]}), + ) + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} event = MockEvent( @@ -364,7 +374,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -388,7 +398,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_presence(events=events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -407,7 +417,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events=events)) - self.assertEquals(events, results) + self.assertEqual(events, results) def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -428,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) results = self.get_success(user_filter.filter_room_state(events)) - self.assertEquals([], results) + self.assertEqual([], results) def test_filter_rooms(self): definition = { @@ -444,7 +454,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) - self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): @@ -486,7 +496,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): Filter(self.hs, definition)._check_event_relations(events) ) ) - self.assertEquals(filtered_events, events[1:]) + self.assertEqual(filtered_events, events[1:]) def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} @@ -497,8 +507,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter_id, 0) - self.assertEquals( + self.assertEqual(filter_id, 0) + self.assertEqual( user_filter_json, ( self.get_success( @@ -524,6 +534,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals(filter.get_filter_json(), user_filter_json) + self.assertEqual(filter.get_filter_json(), user_filter_json) - self.assertRegexpMatches(repr(filter), r"") + self.assertRegex(repr(filter), r"") diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dcf0110c1..483d5463a 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -8,25 +8,25 @@ from tests import unittest class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( @@ -39,25 +39,25 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( @@ -70,29 +70,29 @@ class TestRatelimiter(unittest.HomeserverTestCase): as_requester = create_requester("@user:example.com", app_service=appservice) limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) - self.assertEquals(-1, time_allowed) + self.assertEqual(-1, time_allowed) def test_allowed_via_ratelimit(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # Shouldn't raise @@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) # First attempt should be allowed @@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_pruning(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1 ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): """Test that users that have ratelimiting disabled in the DB aren't ratelimited. """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = "@user:test" requester = create_requester(user_id) @@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_multiple_actions(self): limiter = Ratelimiter( - store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3 + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -246,7 +246,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0) ) self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that, after doing these 3 actions, we can't do any more action without # waiting. @@ -254,7 +254,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0) ) self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting we can do only 1 action. allowed, time_allowed = self.get_success_or_raise( @@ -269,7 +269,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10) @@ -277,7 +277,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertFalse(allowed) # The time allowed doesn't change despite allowed being False because, while we # don't allow 2 actions, we could still do 1. - self.assertEquals(10.0, time_allowed) + self.assertEqual(10.0, time_allowed) # Test that after waiting a bit more we can do 2 actions. allowed, time_allowed = self.get_success_or_raise( @@ -286,4 +286,4 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertTrue(allowed) # The time allowed is the current time because we could still repeat the action # once. - self.assertEquals(20.0, time_allowed) + self.assertEqual(20.0, time_allowed) diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 19eb4c79d..df731eb59 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -32,7 +32,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -40,7 +40,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -51,21 +51,21 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 29 days. The user has now not posted for 29 days. self.reactor.advance(29 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 30 days. self.reactor.advance(ONE_DAY_IN_SECONDS) # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_minimum_usage_using_default_config(self): @@ -84,7 +84,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Advance 30 days (+ 1 second, because strict inequality causes issues if we are @@ -92,7 +92,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) # Send a message (this counts as activity) @@ -103,14 +103,14 @@ class PhoneHomeTestCase(HomeserverTestCase): self.reactor.advance(2 * 60 * 60) # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance 27 days. The user has now not posted for 27 days. self.reactor.advance(27 * ONE_DAY_IN_SECONDS) # The user is still counted. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) # Advance another day. The user has now not posted for 28 days. @@ -119,7 +119,7 @@ class PhoneHomeTestCase(HomeserverTestCase): # The user is now no longer counted in R30. # (This is because the user_ips table has been pruned, which by default # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) def test_r30_user_must_be_retained_for_at_least_a_month(self): @@ -135,7 +135,7 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "message", tok=access_token) # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 0}) for _ in range(30): @@ -144,14 +144,16 @@ class PhoneHomeTestCase(HomeserverTestCase): self.helper.send(room_id, "I'm still here", tok=access_token) # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success( + self.hs.get_datastores().main.count_r30_users() + ) self.assertEqual(r30_results, {"all": 0}) self.reactor.advance(ONE_DAY_IN_SECONDS) self.helper.send(room_id, "Still here!", tok=access_token) # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) @@ -196,7 +198,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the R30 results do not count that user. r30_results = self.get_success(store.count_r30v2_users()) @@ -275,7 +277,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check the user does not contribute to R30 yet. r30_results = self.get_success(store.count_r30v2_users()) @@ -347,7 +349,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # (user_daily_visits is updated every 5 minutes using a looping call.) self.reactor.advance(FIVE_MINUTES_IN_SECONDS) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user does not contribute to R30v2, even though it's been # more than 30 days since registration. diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8fb6687f8..1cbb05935 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,8 +68,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made + self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed def test_single_service_down(self): @@ -92,9 +94,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(0, txn.send.call_count) # txn not sent though - self.assertEquals(0, txn.complete.call_count) # or completed + self.assertEqual(0, txn.send.call_count) # txn not sent though + self.assertEqual(0, txn.complete.call_count) # or completed def test_single_service_up_txn_not_sent(self): # Test: The AS is up and the txn is not sent. A Recoverer is made and @@ -114,12 +118,17 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) - self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made - self.assertEquals(1, self.recoverer.recover.call_count) # and invoked - self.assertEquals(1, len(self.txnctrl.recoverers)) # and stored - self.assertEquals(0, txn.complete.call_count) # txn not completed + self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made + self.assertEqual(1, self.recoverer.recover.call_count) # and invoked + self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored + self.assertEqual(0, txn.complete.call_count) # txn not completed self.store.set_appservice_state.assert_called_once_with( service, ApplicationServiceState.DOWN # service marked as down ) @@ -152,17 +161,17 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(True) txn.complete = simple_async_mock(None) # wait for exp backoff self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(1, txn.complete.call_count) # 2 because it needs to get None to know there are no more txns - self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(2, self.store.get_oldest_unsent_txn.call_count) self.callback.assert_called_once_with(self.recoverer) - self.assertEquals(self.recoverer.service, self.service) + self.assertEqual(self.recoverer.service, self.service) def test_recover_retry_txn(self): txn = Mock() @@ -178,26 +187,26 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() - self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) + self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = simple_async_mock(False) txn.complete = simple_async_mock(None) self.clock.advance_time(2) - self.assertEquals(1, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(1, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(4) - self.assertEquals(2, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(2, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) self.clock.advance_time(8) - self.assertEquals(3, txn.send.call_count) - self.assertEquals(0, txn.complete.call_count) - self.assertEquals(0, self.callback.call_count) + self.assertEqual(3, txn.send.call_count) + self.assertEqual(0, txn.complete.call_count) + self.assertEqual(0, self.callback.call_count) txn.send = simple_async_mock(True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) - self.assertEquals(1, txn.send.call_count) # new mock reset call count - self.assertEquals(1, txn.complete.call_count) + self.assertEqual(1, txn.send.call_count) # new mock reset call count + self.assertEqual(1, txn.complete.call_count) self.callback.assert_called_once_with(self.recoverer) @@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -231,12 +240,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): # Tests that each service has its own queue, and that they don't block @@ -261,16 +272,16 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): srv_1_defer = defer.Deferred() @@ -288,28 +299,38 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) - self.assertEquals(3, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) + self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -324,15 +345,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) - self.assertEquals(1, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.assertEqual(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() @@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) - self.assertEquals(2, self.txn_ctrl.send.call_count) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index a72a0103d..694020fbe 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -63,14 +63,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", @@ -97,14 +97,14 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) - self.assertEquals( + self.assertEqual( event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) - self.assertEquals( + self.assertEqual( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a..d00ef24ca 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -76,7 +76,7 @@ class FakeRequest: @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): - self.assertEquals(getattr(current_context(), "request", None), expected) + self.assertEqual(getattr(current_context(), "request", None), expected) return val def test_verify_json_objects_for_server_awaits_previous_requests(self): @@ -96,7 +96,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def first_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_11") + # self.assertEqual(current_context().request.id, "context_11") self.assertEqual(server_name, "server10") self.assertEqual(key_ids, [get_key_id(key1)]) self.assertEqual(minimum_valid_until_ts, 0) @@ -137,7 +137,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def second_lookup_fetch( server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: - # self.assertEquals(current_context().request.id, "context_12") + # self.assertEqual(current_context().request.id, "context_12") return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() @@ -179,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -272,7 +272,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.get_datastore().store_server_verify_keys( + r = self.hs.get_datastores().main.store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -448,7 +448,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -564,7 +564,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) @@ -683,7 +683,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( - self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) res = key_json[lookup_triplet] self.assertEqual(len(res), 1) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index ca27388ae..defbc68c1 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.user_id = self.register_user("u1", "pass") diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 1dea09e48..45e3395b3 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -395,7 +395,7 @@ class SerializeEventTestCase(unittest.TestCase): return serialize_event(ev, 1479807801915, only_event_fields=fields) def test_event_fields_works_with_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] ), @@ -403,7 +403,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -416,7 +416,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -429,7 +429,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_works_with_nested_dot_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -445,7 +445,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_unknown_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -458,7 +458,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_non_dict_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -471,7 +471,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_nops_with_array_keys(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( sender="@alice:localhost", @@ -484,7 +484,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_all_fields_if_empty(self): - self.assertEquals( + self.assertEqual( self.serialize( MockEvent( type="foo", diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index e40ef9587..9f1115dd2 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -50,19 +50,19 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) # Artificially raise the complexity - store = self.hs.get_datastore() + store = self.hs.get_datastores().main store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) @@ -149,7 +149,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = ( + self.hs.get_datastores().main.get_current_state_event_counts = ( lambda x: make_awaitable(600) ) diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index f0aa8ed9d..2873b4d43 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -64,7 +64,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): Dictionary of { event_id: str, stream_ordering: int } """ event_id, stream_ordering = self.get_success( - self.hs.get_datastore().db_pool.execute( + self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", None, """ @@ -125,7 +125,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.pump() lsso_1 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -141,7 +141,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] lsso_2 = self.get_success( - self.hs.get_datastore().get_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.get_destination_last_successful_stream_ordering( "host2" ) ) @@ -216,7 +216,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # let's also clear any backoffs self.get_success( - self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + self.hs.get_datastores().main.set_destination_retry_timings( + "host2", None, 0, 0 + ) ) # bring the remote online and clear the received pdu list @@ -296,13 +298,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # destination_rooms should already be populated, but let us pretend that we already # sent (successfully) up to and including event id 2 - event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2)) # also fetch event 5 so we know its last_successful_stream_ordering later - event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering ) ) @@ -359,7 +361,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # ASSERT: # - All servers are up to date so none should have outstanding catch-up outstanding_when_successful = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(outstanding_when_successful, []) @@ -370,7 +372,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - Mark zzzerver as being backed-off from now = self.clock.time_msec() self.get_success( - self.hs.get_datastore().set_destination_retry_timings( + self.hs.get_datastores().main.set_destination_retry_timings( "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day ) ) @@ -382,14 +384,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # - all remotes are outstanding # - they are returned in batches of 25, in order outstanding_1 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None) ) self.assertEqual(len(outstanding_1), 25) self.assertEqual(outstanding_1, server_names[0:25]) outstanding_2 = self.get_success( - self.hs.get_datastore().get_catch_up_outstanding_destinations( + self.hs.get_datastores().main.get_catch_up_outstanding_destinations( outstanding_1[-1] ) ) @@ -457,7 +459,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ) self.get_success( - self.hs.get_datastore().set_destination_last_successful_stream_ordering( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering ) ) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py new file mode 100644 index 000000000..ec8864daf --- /dev/null +++ b/tests/federation/test_federation_client.py @@ -0,0 +1,149 @@ +# Copyright 2022 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +import twisted.web.client +from twisted.internet import defer +from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import FederatingHomeserverTestCase + + +class FederationClientTest(FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + super().prepare(reactor, clock, homeserver) + + # mock out the Agent used by the federation client, which is easier than + # catching the HTTPS connection and do the TLS stuff. + self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) + homeserver.get_federation_http_client().agent = self._mock_agent + + def test_get_room_state(self): + creator = f"@creator:{self.OTHER_SERVER_NAME}" + test_room_id = "!room_id" + + # mock up some events to use in the response. + # In real life, these would have things in `prev_events` and `auth_events`, but that's + # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. + create_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": {"creator": creator}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 500, + } + ) + member_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 600, + } + ) + pl_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.power_levels", + "sender": creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.return_value = defer.succeed( + _mock_response( + { + "pdus": [ + create_event_dict, + member_event_dict, + pl_event_dict, + ], + "auth_chain": [ + create_event_dict, + member_event_dict, + ], + } + ) + ) + + # now fire off the request + state_resp, auth_resp = self.get_success( + self.hs.get_federation_client().get_room_state( + "yet_another_server", + test_room_id, + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + # ... and that the response is correct. + + # the auth_resp should be empty because all the events are also in state + self.assertEqual(auth_resp, []) + + # all of the events should be returned in state_resp, though not necessarily + # in the same order. We just check the type on the assumption that if the type + # is right, so is the rest of the event. + self.assertCountEqual( + [e.type for e in state_resp], + ["m.room.create", "m.room.member", "m.room.power_levels"], + ) + + +def _mock_response(resp: JsonDict): + body = json.dumps(resp).encode("utf-8") + + def deliver_body(p: Protocol): + p.dataReceived(body) + p.connectionLost(Failure(twisted.web.client.ResponseDone())) + + response = mock.Mock( + code=200, + phrase=b"OK", + headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), + length=len(body), + deliverBody=deliver_body, + ) + mock.seal(response) + return response diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b2376e2db..60e0c31f4 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -176,7 +176,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def get_users_who_share_room_with_user(user_id): return defer.succeed({"@user2:host2"}) - hs.get_datastore().get_users_who_share_room_with_user = ( + hs.get_datastores().main.get_users_who_share_room_with_user = ( get_users_who_share_room_with_user ) @@ -395,7 +395,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server @@ -445,7 +445,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # run the prune job self.reactor.advance(10) self.get_success( - self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1) ) # recover the server diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index d084919ef..30e7e5093 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -125,7 +125,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual( channel.json_body["room_version"], @@ -157,7 +157,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -189,7 +189,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body def test_send_join(self): @@ -209,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # we should get complete room state back returned_state = [ @@ -266,7 +266,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # expect a reduced room state returned_state = [ diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index a7031a55f..c2320ce13 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase): self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) + self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertEqual(parsed_response.servers_in_room, None, parsed_response) + + def test_partial_state(self) -> None: + """Check that the partial_state flag is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = { + "org.matrix.msc3706.partial_state": True, + } + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertTrue(parsed_response.partial_state) + + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 686f42ab4..648a01618 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -169,7 +169,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase): self.assertIn(event_type, expected_room_state) # Check the state content matches - self.assertEquals( + self.assertEqual( expected_room_state[event_type]["content"], event["content"] ) @@ -198,7 +198,7 @@ class FederationKnockingTestCase( ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # We're not going to be properly signing events as our remote homeserver is fake, # therefore disable event signature checks. @@ -256,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -266,11 +266,11 @@ class FederationKnockingTestCase( knock_event = channel.json_body["event"] # Check that the event has things we expect in it - self.assertEquals(knock_event["room_id"], room_id) - self.assertEquals(knock_event["sender"], fake_knocking_user_id) - self.assertEquals(knock_event["state_key"], fake_knocking_user_id) - self.assertEquals(knock_event["type"], EventTypes.Member) - self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK) + self.assertEqual(knock_event["room_id"], room_id) + self.assertEqual(knock_event["sender"], fake_knocking_user_id) + self.assertEqual(knock_event["state_key"], fake_knocking_user_id) + self.assertEqual(knock_event["type"], EventTypes.Member) + self.assertEqual(knock_event["content"]["membership"], Membership.KNOCK) # Turn the event json dict into a proper event. # We won't sign it properly, but that's OK as we stub out event auth in `prepare` @@ -294,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index eb62addda..5f001c33b 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -13,7 +13,7 @@ # limitations under the License. from tests import unittest -from tests.unittest import override_config +from tests.unittest import DEBUG, override_config class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @@ -26,7 +26,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(403, channel.code) + self.assertEqual(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): @@ -37,4 +37,22 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "GET", "/_matrix/federation/v1/publicRooms", ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) + + @DEBUG + def test_edu_debugging_doesnt_explode(self): + """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled. + + Remove this when we strip out issue_8631_logger. + """ + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn_id_1234/", + content={ + "edus": [ + {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + ], + "pdus": [], + }, + ) + self.assertEqual(200, channel.code) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index fe57ff267..072e6bbcd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -38,7 +46,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore.return_value = self.mock_store + hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( @@ -139,8 +147,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str ) - self.assertEquals(result.room_id, room_id) - self.assertEquals(result.servers, servers) + self.assertEqual(result.room_id, room_id) + self.assertEqual(result.servers, servers) def test_get_3pe_protocols_no_appservices(self): self.mock_store.get_app_services.return_value = [] @@ -148,7 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_no_protocols(self): service = self._mkservice(False, []) @@ -157,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_not_called() - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_protocol_no_response(self): service = self._mkservice(False, ["my-protocol"]) @@ -169,7 +177,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals(response, {}) + self.assertEqual(response, {}) def test_get_3pe_protocols_select_one_protocol(self): service = self._mkservice(False, ["my-protocol"]) @@ -183,7 +191,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -199,7 +207,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.get_3pe_protocol.assert_called_once_with( service, "my-protocol" ) - self.assertEquals( + self.assertEqual( response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} ) @@ -214,7 +222,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) self.mock_as_api.get_3pe_protocol.assert_called() - self.assertEquals( + self.assertEqual( response, { "my-protocol": {"x-protocol-data": 42, "instances": []}, @@ -246,7 +254,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.ensureDeferred(self.handler.get_3pe_protocols()) ) # It's expected that the second service's data doesn't appear in the response - self.assertEquals( + self.assertEqual( response, { "my-protocol": { @@ -355,7 +363,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) # A user on the homeserver. self.local_user_device_id = "local_device" @@ -426,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -494,7 +511,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create a fake device per message. We can't send to-device messages to # a device that doesn't exist. self.get_success( - self.hs.get_datastore().db_pool.simple_insert_many( + self.hs.get_datastores().main.db_pool.simple_insert_many( desc="test_application_services_receive_burst_of_to_device", table="devices", keys=("user_id", "device_id"), @@ -510,7 +527,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Seed the device_inbox table with our fake messages self.get_success( - self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {}) ) # Now have local_user send a final to-device message to exclusive_as_user. All unsent @@ -538,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -580,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastores().main.add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastores().main + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastores().main.store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args, _last_kwargs = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615..0c6e55e72 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -129,7 +129,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) @@ -140,7 +140,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) self.get_failure( @@ -156,7 +156,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) @@ -175,7 +175,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # If in monthly active cohort - self.hs.get_datastore().user_last_seen_monthly_active = Mock( + self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( @@ -192,7 +192,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception @@ -202,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.hs.get_datastore().get_monthly_active_count = Mock( + self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) self.get_success( diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff894..a26722884 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -77,7 +77,7 @@ class CasHandlerTestCase(HomeserverTestCase): def test_map_cas_user_to_existing_user(self): """Existing users can log in with CAS account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 01096a158..ddda36c5a 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 43031e07e..683677fd0 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def prepare(self, reactor, clock, hs): @@ -263,7 +263,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_dehydrate_and_rehydrate_device(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 0ea4e753e..6e403a87c 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,7 +46,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler = hs.get_directory_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.my_room = RoomAlias.from_string("#my-room:test") self.your_room = RoomAlias.from_string("#your-room:test") @@ -63,7 +63,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.my_room)) - self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) + self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): self.mock_federation.make_query.return_value = make_awaitable( @@ -72,7 +72,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_association(self.remote_room)) - self.assertEquals( + self.assertEqual( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result ) self.mock_federation.make_query.assert_called_with( @@ -94,7 +94,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + self.assertEqual({"room_id": "!8765asdf:test", "servers": ["test"]}, response) class TestCreateAlias(unittest.HomeserverTestCase): @@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -224,7 +224,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -243,7 +243,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.admin_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -269,7 +269,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): create_requester(self.test_user), self.room_alias ) ) - self.assertEquals(self.room_id, result) + self.assertEqual(self.room_id, result) # Confirm the alias is gone. self.get_failure( @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -411,7 +411,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23test%3Atest", {"room_id": room_id}, ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) @@ -421,7 +421,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23unofficial_test%3Atest", {"room_id": room_id}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def test_denied_during_creation(self): """A room alias that is not allowed should be rejected during creation.""" @@ -443,8 +443,8 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): "GET", b"directory/room/%23unofficial_test%3Atest", ) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(channel.json_body["room_id"], room_id) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(channel.json_body["room_id"], room_id) class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): @@ -572,7 +572,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() @@ -585,7 +585,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list is enabled so we should get some results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) self.room_list_handler.enable_room_list_search = False @@ -593,7 +593,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we should get no results channel = self.make_request("GET", b"publicRooms") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms @@ -601,4 +601,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 734ed84d7..9338ab92e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = hs.get_e2e_keys_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 496b58172..e8b4e39d1 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state_store = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 5816295d8..f4f7ab484 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token( + self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a552d8182..e8418b663 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -856,7 +856,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user3 = UserID.from_string("@test_user_3:test") self.get_success( store.register_user(user_id=user3.to_string(), password_hash=None) @@ -872,7 +872,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) @@ -996,7 +996,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a6..49d832de8 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,7 @@ class CustomAuthProvider: def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -122,7 +122,7 @@ class PasswordCustomAuthProvider: auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) - # Initiate the UIA flow. username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, - ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) + res = self._do_uia_assert_mock_not_called(username, m) - # Check that the callback hasn't been called yet. - m.assert_not_called() - - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 671dc7d08..6ddec9ecf 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_offline_to_online(self): wheel_timer = Mock() @@ -61,11 +61,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -104,11 +104,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -149,11 +149,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertTrue(federation_ping) self.assertTrue(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 3) + self.assertEqual(wheel_timer.insert.call_count, 3) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -191,11 +191,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertTrue(persist_and_notify) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 2) + self.assertEqual(wheel_timer.insert.call_count, 2) wheel_timer.insert.assert_has_calls( [ call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), @@ -227,10 +227,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): self.assertFalse(persist_and_notify) self.assertFalse(federation_ping) self.assertFalse(state.currently_active) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -259,10 +259,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) - self.assertEquals(wheel_timer.insert.call_count, 0) + self.assertEqual(wheel_timer.insert.call_count, 0) def test_online_to_idle(self): wheel_timer = Mock() @@ -281,12 +281,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): ) self.assertTrue(persist_and_notify) - self.assertEquals(new_state.state, state.state) - self.assertEquals(state.last_federation_update_ts, now) - self.assertEquals(new_state.state, state.state) - self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEqual(new_state.state, state.state) + self.assertEqual(state.last_federation_update_ts, now) + self.assertEqual(new_state.state, state.state) + self.assertEqual(new_state.status_msg, state.status_msg) - self.assertEquals(wheel_timer.insert.call_count, 1) + self.assertEqual(wheel_timer.insert.call_count, 1) wheel_timer.insert.assert_has_calls( [ call( @@ -357,8 +357,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) + self.assertEqual(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -380,8 +380,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.BUSY) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.BUSY) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" @@ -399,8 +399,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" @@ -420,8 +420,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.ONLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.ONLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" @@ -440,7 +440,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -477,8 +477,8 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) - self.assertEquals(new_state.state, PresenceState.OFFLINE) - self.assertEquals(new_state.status_msg, status_msg) + self.assertEqual(new_state.state, PresenceState.OFFLINE) + self.assertEqual(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" @@ -497,7 +497,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(state, new_state) + self.assertEqual(state, new_state) class PresenceHandlerTestCase(unittest.HomeserverTestCase): @@ -891,7 +891,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2 = EventBuilderFactory(hs) # self.event_builder_for_2.hostname = "test2" - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 60235e569..972cbac6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,7 +48,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") self.bob = UserID.from_string("@4567:test") @@ -65,7 +65,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.frank)) - self.assertEquals("Frank", displayname) + self.assertEqual("Frank", displayname) def test_set_my_name(self): self.get_success( @@ -74,7 +74,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -90,7 +90,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -118,7 +118,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( ( self.get_success( self.store.get_profile_displayname(self.frank.localpart) @@ -150,7 +150,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): displayname = self.get_success(self.handler.get_displayname(self.alice)) - self.assertEquals(displayname, "Alice") + self.assertEqual(displayname, "Alice") self.mock_federation.make_query.assert_called_with( destination="remote", query_type="profile", @@ -172,7 +172,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals({"displayname": "Caroline"}, response) + self.assertEqual({"displayname": "Caroline"}, response) def test_get_my_avatar(self): self.get_success( @@ -182,7 +182,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) - self.assertEquals("http://my.server/me.png", avatar_url) + self.assertEqual("http://my.server/me.png", avatar_url) def test_set_my_avatar(self): self.get_success( @@ -193,7 +193,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) @@ -207,7 +207,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -235,7 +235,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) @@ -325,7 +325,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5de89c873..5081b9757 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -314,4 +314,4 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ): """Tests that the _filter_out_hidden returns the expected output""" filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") - self.assertEquals(filtered_events, expected_output) + self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index cd6f2c77a..45fd30cf4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -154,7 +154,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.lots_of_users = 100 self.small_number_of_users = 1 @@ -167,12 +167,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, frank.localpart, "Frankie") ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertIsInstance(result_token, str) self.assertGreater(len(result_token), 20) def test_if_user_exists(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main frank = UserID.from_string("@frank:test") self.get_success( store.register_user(user_id=frank.to_string(), password_hash=None) @@ -183,7 +183,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): result_user_id, result_token = self.get_success( self.get_or_create_user(requester, local_part, None) ) - self.assertEquals(result_user_id, user_id) + self.assertEqual(result_user_id, user_id) self.assertTrue(result_token is not None) @override_config({"limit_usage_by_mau": False}) @@ -760,7 +760,7 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_registration_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main @override_config({"auto_join_rooms": ["#room:remotetest"]}) def test_auto_create_auto_join_remote_room(self): diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 51b22d299..cff07a897 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -157,35 +157,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): state_key=room_id, ) - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - def _assert_hierarchy( self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: @@ -251,11 +222,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_simple_space(self): """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -271,12 +240,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, room, self.token) rooms.append(room) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The spaces result should have the space and the first 50 rooms in it, - # along with the links from space -> room for those 50 rooms. - expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]] - self._assert_rooms(result, expected) - # The result should have the space and the rooms in it, along with the links # from space -> room. expected = [(self.space, rooms)] + [(room, []) for room in rooms] @@ -300,10 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): token2 = self.login("user2", "pass") # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -316,7 +276,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"join_rule": JoinRules.INVITE}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -329,9 +288,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.WORLD_READABLE}, tok=self.token, ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) @@ -344,7 +300,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure( self.handler.get_room_hierarchy(create_requester(user2), self.space), AuthError, @@ -353,19 +308,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) self.get_failure( self.handler.get_room_hierarchy( create_requester(user2), "#not-a-space:" + self.hs.hostname @@ -496,7 +444,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space. self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) expected = [ ( self.space, @@ -520,7 +467,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(user2), self.space) @@ -554,8 +500,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(subspace, self.room, token=self.token) self._add_child(subspace, room2, self.token) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should include each room a single time and each link. expected = [ (self.space, [self.room, room2, subspace]), @@ -563,7 +507,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (subspace, [subroom, self.room, room2]), (subroom, ()), ] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -715,7 +658,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_unknown_room_version(self): """ - If an room with an unknown room version is encountered it should not cause + If a room with an unknown room version is encountered it should not cause the entire summary to skip. """ # Poke the database and update the room version to an unknown one. @@ -727,11 +670,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): desc="updated-room-version", ) ) + # Invalidate method so that it returns the currently updated version + # instead of the cached version. + self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,)) - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have only the space, along with a link from space -> room. expected = [(self.space, [self.room])] - self._assert_rooms(result, expected) result = self.get_success( self.handler.get_room_hierarchy(create_requester(self.user), self.space) @@ -775,41 +719,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "world_readable": True, } - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), (subspace, [subroom]), (subroom, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -881,7 +802,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], + "allowed_room_ids": [], }, ), ( @@ -890,7 +811,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): "room_id": restricted_accessible_room, "world_readable": False, "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], + "allowed_room_ids": [self.room], }, ), ( @@ -929,30 +850,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ], ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, subspace]), (self.room, ()), @@ -976,7 +879,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (world_readable_room, ()), (joined_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", @@ -1010,31 +912,17 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): }, ) - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - async def summarize_remote_room_hierarchy(_self, room, suggested_only): return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) - with mock.patch( - "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - expected = [ (self.space, [self.room, fed_room]), (self.room, ()), (fed_room, ()), ] - self._assert_rooms(result, expected) with mock.patch( "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e..23941abed 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -142,7 +142,7 @@ class SamlHandlerTestCase(HomeserverTestCase): @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): """Existing users can log in with SAML account.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) @@ -217,7 +217,7 @@ class SamlHandlerTestCase(HomeserverTestCase): sso_handler.render_error = Mock(return_value=None) # register a user to occupy the first-choice MXID - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 56207f4db..ecd78fa36 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = self.hs.get_stats_handler() def _add_background_updates(self): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 07a760e91..3aedc0767 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -69,7 +69,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False @@ -80,7 +80,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) - self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_unknown_room_version(self): """ @@ -122,7 +122,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): b"{}", tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # The rooms should appear in the sync response. result = self.get_success( @@ -248,7 +248,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the join. mocked_get_prev_events = patch.object( - self.hs.get_datastore(), + self.hs.get_datastores().main, "get_prev_events_for_room", new_callable=MagicMock, return_value=make_awaitable([last_room_creation_event_id]), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 000f9b9fd..f91a80b9f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -91,7 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources.typing - self.datastore = hs.get_datastore() + self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( return_value=defer.succeed(None) ) @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -169,13 +169,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -239,13 +239,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -259,7 +259,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_recv_not_in_room(self): self.room_members = [U_APPLE, U_ONION] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) channel = self.make_request( "PUT", @@ -278,7 +278,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_not_called() - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -288,8 +288,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals(events[0], []) - self.assertEquals(events[1], 0) + self.assertEqual(events[0], []) + self.assertEqual(events[1], 0) @override_config({"send_federation": True}) def test_stopped_typing(self): @@ -302,7 +302,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.handler._member_typing_until[member] = 1002000 self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.stopped_typing( @@ -332,13 +332,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): try_trailing_slash_on_400=True, ) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -346,7 +346,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_typing_timeout(self): self.room_members = [U_APPLE, U_BANANA] - self.assertEquals(self.event_source.get_current_key(), 0) + self.assertEqual(self.event_source.get_current_key(), 0) self.get_success( self.handler.started_typing( @@ -360,7 +360,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -370,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -385,7 +385,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -395,7 +395,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) @@ -414,7 +414,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) events = self.get_success( self.event_source.get_new_events( user=U_APPLE, @@ -424,7 +424,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 482c90ef6..92012cd6f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -77,7 +77,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.event_creation_handler = self.hs.get_event_creation_handler() @@ -1042,7 +1042,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing @@ -1053,5 +1053,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): b'{"search_term":"user2"}', access_token=u1_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index c49be33b9..77ce8432a 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -65,9 +65,9 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(test_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): @@ -88,8 +88,8 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_from_cache(self): @@ -114,8 +114,8 @@ class SrvResolverTestCase(unittest.TestCase): self.assertFalse(dns_client_mock.lookupService.called) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) @defer.inlineCallbacks def test_empty_cache(self): @@ -144,8 +144,8 @@ class SrvResolverTestCase(unittest.TestCase): servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) + self.assertEqual(len(servers), 0) + self.assertEqual(len(cache), 0) def test_disabled_service(self): """ @@ -201,6 +201,6 @@ class SrvResolverTestCase(unittest.TestCase): servers = self.successResultOf(resolve_d) - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, b"host") + self.assertEqual(len(servers), 1) + self.assertEqual(servers, cache[service_name]) + self.assertEqual(servers[0].host, b"host") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index d16cd141a..c3f20f969 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b64..7a3b0d675 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -102,13 +102,13 @@ class EmailPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token) + self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. self.get_success( - hs.get_datastore().user_add_threepid( + hs.get_datastores().main.user_add_threepid( self.user_id, "email", "a@example.com", 0, 0 ) ) @@ -128,7 +128,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -375,7 +375,7 @@ class EmailPusherTests(HomeserverTestCase): # check that the pusher for that email address has been deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -388,14 +388,14 @@ class EmailPusherTests(HomeserverTestCase): # This resembles the old behaviour, which the background update below is intended # to clean up. self.get_success( - self.hs.get_datastore().user_delete_threepid( + self.hs.get_datastores().main.user_delete_threepid( self.user_id, "email", "a@example.com" ) ) # Run the "remove_deleted_email_pushers" background job self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="background_updates", values={ "update_name": "remove_deleted_email_pushers", @@ -406,14 +406,14 @@ class EmailPusherTests(HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.hs.get_datastore().db_pool.updates._all_done = False + self.hs.get_datastores().main.db_pool.updates._all_done = False # Now let's actually drive the updates to completion self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 0) @@ -428,7 +428,7 @@ class EmailPusherTests(HomeserverTestCase): """ # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -439,7 +439,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -458,7 +458,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c068d329a..c284beb37 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -62,7 +62,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -108,7 +108,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -138,7 +138,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -149,7 +149,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -170,7 +170,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by({"user_name": user_id}) + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -224,7 +224,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -344,7 +344,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -430,7 +430,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -507,7 +507,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 - ) + self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self): @@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # - # We're counting every unread message, so there should now be 4 since the + # We're counting every unread message, so there should now be 3 since the # last read receipt - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 - ) + self._check_push_attempt(6, 3) def _test_push_unread_count(self): """ @@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase): Note that: * Sending messages will cause push notifications to go out to relevant users - * Sending a read receipt will cause a "badge update" notification to go out to - the user that sent the receipt + * Sending a read receipt will cause the HTTP pusher to check whether the unread + count has changed since the last push notification. If so, a "badge update" + notification goes out to the user that sent the receipt """ # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -616,7 +613,7 @@ class HTTPPusherTests(HomeserverTestCase): # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id @@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase): # position in the room. We'll set the read position to this event in a moment first_message_event_id = response["event_id"] - # Advance time a bit (so the pusher will register something has happened) and - # make the push succeed - self.push_attempts[0][0].callback({}) - self.pump() + expected_push_attempts = 1 + self._check_push_attempt(expected_push_attempts, 0) - # Check our push made it + self._send_read_request(access_token, first_message_event_id, room_id) + + # Unread count has not changed. Therefore, ensure that read request does not + # trigger a push notification. self.assertEqual(len(self.push_attempts), 1) - self.assertEqual( - self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" - ) + # Send another message + response2 = self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + second_message_event_id = response2["event_id"] + + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 1) + + self._send_read_request(access_token, second_message_event_id, room_id) + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 0) + + # If we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room. Otherwise, it + # should. Therefore, the last call to _check_push_attempt is done in the + # caller method. + self.helper.send(room_id, body="Hello?", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + def _advance_time_and_make_push_succeed(self, expected_push_attempts): + self.pump() + self.push_attempts[expected_push_attempts - 1][0].callback({}) + + def _check_push_attempt( + self, expected_push_attempts: int, expected_unread_count_last_push: int + ) -> None: + """ + Makes sure that the last expected push attempt succeeds and checks whether + it contains the expected unread count. + """ + self._advance_time_and_make_push_succeed(expected_push_attempts) + # Check our push made it + self.assertEqual(len(self.push_attempts), expected_push_attempts) + _, push_url, push_body = self.push_attempts[expected_push_attempts - 1] + self.assertEqual( + push_url, + "http://example.com/_matrix/push/v1/notify", + ) # Check that the unread count for the room is 0 # # The unread count is zero as the user has no read receipt in the room yet self.assertEqual( - self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + push_body["notification"]["counts"]["unread"], + expected_unread_count_last_push, ) + def _send_read_request(self, access_token, message_event_id, room_id): # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase): # count goes down channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id), {}, access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time and make the push succeed - self.push_attempts[1][0].callback({}) - self.pump() - - # Unread count is still zero as we've read the only message in the room - self.assertEqual(len(self.push_attempts), 2) - self.assertEqual( - self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 - ) - - # Send another message - self.helper.send( - room_id, body="How's the weather today?", tok=other_access_token - ) - - # Advance time and make the push succeed - self.push_attempts[2][0].callback({}) - self.pump() - - # This push should contain an unread count of 1 as there's now been one - # message since our last read receipt - self.assertEqual(len(self.push_attempts), 3) - self.assertEqual( - self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 - ) - - # Since we're grouping by room, sending more messages shouldn't increase the - # unread count, as they're all being sent in the same room - self.helper.send(room_id, body="Hello?", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[3][0].callback({}) - - self.helper.send(room_id, body="Hello??", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[4][0].callback({}) - - self.helper.send(room_id, body="HELLO???", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[5][0].callback({}) - - self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index a52e89e40..3849beb9d 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -14,6 +14,8 @@ from typing import Any, Dict +import frozendict + from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator @@ -191,6 +193,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should only match at the start/end of the value", ) + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "FoobaZ"}), + "patterns should match on frozendicts", + ) + # wildcards should match condition = { "kind": "event_match", diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9fc50f885..a7a05a564 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -68,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool + self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool # Normally we'd pass in the handler to `setup_test_homeserver`, which would # eventually hit "Install @cache_in_self attributes" in tests/utils.py. @@ -233,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # We may have an attempt to connect to redis for the external cache already. self.connect_any_redis_attempts() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -332,7 +332,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(worker_hs, port), ) - store = worker_hs.get_datastore() + store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool # Set up TCP replication between master and the new worker if we don't diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 83e89383f..85be79d19 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -30,8 +30,8 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.reconnect() - self.master_store = hs.get_datastore() - self.slaved_store = self.worker_hs.get_datastore() + self.master_store = hs.get_datastores().main + self.slaved_store = self.worker_hs.get_datastores().main self.storage = hs.get_storage() def replicate(self): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index eca6a443a..17dc42fd3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -59,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEquals + # whether lists of events match using assertEqual self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super().setUp() diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index cdd052001..50fbff5f3 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -23,7 +23,7 @@ from tests.replication._base import BaseStreamTestCase class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_room_account_data_limit(self): """Test replication with many room account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] @@ -69,7 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): def test_update_function_global_account_data_limit(self): """Test replication with many global account data updates""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # generate lots of account data updates updates = [] diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f198a9488..f9d5da723 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -136,7 +136,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events = [ @@ -291,7 +291,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # this is the point in the DAG where we make a fork fork_point: List[str] = self.get_success( - self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) events: List[EventBase] = [] diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 38e292c1a..eb0011784 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -32,7 +32,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) @@ -56,7 +56,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() self.get_success( - self.hs.get_datastore().insert_receipt( + self.hs.get_datastores().main.insert_receipt( "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92a5b53e1..ba1a63c0d 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -204,7 +204,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 4094a75f3..8f4f6688c 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -50,7 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): # Register a pusher user_dict = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_dict.token_id diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c..5f142e84c 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -47,7 +47,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_access_token = self.login("otheruser", "pass") self.room_creator = self.hs.get_room_creation_handler() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def default_config(self): conf = super().default_config() @@ -99,7 +99,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): persisted_on_1 = False persisted_on_2 = False - store = self.hs.get_datastore() + store = self.hs.get_datastores().main user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") @@ -166,7 +166,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): user_id = self.register_user("user", "pass") access_token = self.login("user", "pass") - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create two room on the different workers. self._create_room(room_id1, user_id, access_token) @@ -194,7 +194,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() + worker_store2 = worker_hs2.get_datastores().main assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) actx = worker_store2._stream_id_gen.get_next() diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 1e3fe9c62..fb36aa994 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 71068d16c..929bbdc37 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 86aff7575..0d47dd0af 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 8513b1d2d..8354250ec 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 23da0ad73..95282f078 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1909,7 +1909,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1957,7 +1957,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1980,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2010,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2044,7 +2044,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2074,8 +2074,8 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEquals( + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self._store = hs.get_datastore() + self._store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f76..2c855bff9 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() self.server_notices_manager = self.hs.get_server_notices_manager() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 272637e96..a60ea0a56 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): even if the MAU limit is reached. """ handler = self.hs.get_registration_handler() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Set monthly active users to the limit store.get_monthly_active_count = Mock( @@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() # create users and get access tokens @@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) @@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 89d85b0a1..6c4462e74 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1,6 +1,4 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,16 +15,22 @@ import json import os import re from email.parser import Parser -from typing import Optional +from typing import Dict, List, Optional +from unittest.mock import Mock import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -73,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): @@ -100,7 +104,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -139,7 +143,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -189,7 +193,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email_passwort_reset, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -226,7 +230,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to reset password without clicking the link self._reset_password(new_password, session_id, client_secret, expected_code=401) @@ -318,7 +322,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -333,7 +337,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -372,7 +376,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): }, }, ) - self.assertEquals(expected_code, channel.code, channel.result) + self.assertEqual(expected_code, channel.code, channel.result) class DeactivateTestCase(unittest.HomeserverTestCase): @@ -394,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(user_id, tok) - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Check that the user has been marked as deactivated. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) @@ -405,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main inviter_id = self.register_user("inviter", "test") inviter_tok = self.login("inviter", "test") @@ -486,8 +490,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) @@ -505,8 +510,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": True, + "is_guest": True, }, ) @@ -521,15 +527,16 @@ class WhoamiTestCase(unittest.HomeserverTestCase): namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) whoami = self._whoami(as_token) self.assertEqual( whoami, { "user_id": user_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) self.assertFalse(hasattr(whoami, "device_id")) @@ -579,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return self.hs def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") self.user_id_tok = self.login("kermit", "test") @@ -669,7 +676,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) link = self._get_link_from_email() self._validate_token(link) @@ -773,7 +780,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(self.email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEqual(len(self.email_attempts), 1) # Attempt to add email without clicking the link channel = self.make_request( @@ -974,7 +981,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) def _get_link_from_email(self): assert self.email_attempts, "No emails have been sent" @@ -1003,7 +1010,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -1037,3 +1044,197 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} self.assertIn(expected_email, threepids) + + +class AccountStatusTestCase(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["experimental_features"] = {"msc3720_enabled": True} + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.requester = self.register_user("requester", "password") + self.requester_tok = self.login("requester", "password") + self.server_name = homeserver.config.server.server_name + + def test_missing_mxid(self): + """Tests that not providing any MXID raises an error.""" + self._test_status( + users=None, + expected_status_code=400, + expected_errcode=Codes.MISSING_PARAM, + ) + + def test_invalid_mxid(self): + """Tests that providing an invalid MXID raises an error.""" + self._test_status( + users=["bad:test"], + expected_status_code=400, + expected_errcode=Codes.INVALID_PARAM, + ) + + def test_local_user_not_exists(self): + """Tests that the account status endpoints correctly reports that a user doesn't + exist. + """ + user = "@unknown:" + self.hs.config.server.server_name + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_exists(self): + """Tests that the account status endpoint correctly reports that a user doesn't + exist. + """ + user = self.register_user("someuser", "password") + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": False, + }, + }, + expected_failures=[], + ) + + def test_local_user_deactivated(self): + """Tests that the account status endpoint correctly reports a deactivated user.""" + user = self.register_user("someuser", "password") + self.get_success( + self.hs.get_datastores().main.set_user_deactivated_status( + user, deactivated=True + ) + ) + + self._test_status( + users=[user], + expected_statuses={ + user: { + "exists": True, + "deactivated": True, + }, + }, + expected_failures=[], + ) + + def test_mixed_local_and_remote_users(self): + """Tests that if some users are remote the account status endpoint correctly + merges the remote responses with the local result. + """ + # We use 3 users: one doesn't exist but belongs on the local homeserver, one is + # deactivated and belongs on one remote homeserver, and one belongs to another + # remote homeserver that didn't return any result (the federation code should + # mark that user as a failure). + users = [ + "@unknown:" + self.hs.config.server.server_name, + "@deactivated:remote", + "@failed:otherremote", + "@bad:badremote", + ] + + async def post_json(destination, path, data, *a, **kwa): + if destination == "remote": + return { + "account_statuses": { + users[1]: { + "exists": True, + "deactivated": True, + }, + } + } + if destination == "otherremote": + return {} + if destination == "badremote": + # badremote tries to overwrite the status of a user that doesn't belong + # to it (i.e. users[1]) with false data, which Synapse is expected to + # ignore. + return { + "account_statuses": { + users[3]: { + "exists": False, + }, + users[1]: { + "exists": False, + }, + } + } + + # Register a mock that will return the expected result depending on the remote. + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + + # Check that we've got the correct response from the client-side endpoint. + self._test_status( + users=users, + expected_statuses={ + users[0]: { + "exists": False, + }, + users[1]: { + "exists": True, + "deactivated": True, + }, + users[3]: { + "exists": False, + }, + }, + expected_failures=[users[2]], + ) + + def _test_status( + self, + users: Optional[List[str]], + expected_status_code: int = 200, + expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, + expected_failures: Optional[List[str]] = None, + expected_errcode: Optional[str] = None, + ): + """Send a request to the account status endpoint and check that the response + matches with what's expected. + + Args: + users: The account(s) to request the status of, if any. If set to None, no + `user_id` query parameter will be included in the request. + expected_status_code: The expected HTTP status code. + expected_statuses: The expected account statuses, if any. + expected_failures: The expected failures, if any. + expected_errcode: The expected Matrix error code, if any. + """ + content = {} + if users is not None: + content["user_ids"] = users + + channel = self.make_request( + method="POST", + path=self.url, + content=content, + access_token=self.requester_tok, + ) + + self.assertEqual(channel.code, expected_status_code) + + if expected_statuses is not None: + self.assertEqual(channel.json_body["account_statuses"], expected_statuses) + + if expected_failures is not None: + self.assertEqual(channel.json_body["failures"], expected_failures) + + if expected_errcode is not None: + self.assertEqual(channel.json_body["errcode"], expected_errcode) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 27cb856b0..9653f4583 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -32,11 +37,11 @@ from tests.unittest import override_config, skip_unless class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): + def __init__(self, hs: HomeServer) -> None: super().__init__(hs) - self.recaptcha_attempts = [] + self.recaptcha_attempts: List[Tuple[dict, str]] = [] - def check_auth(self, authdict, clientip): + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -49,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ] hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() @@ -60,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker @@ -100,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - def test_fallback_captcha(self): + def test_fallback_captcha(self) -> None: """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( @@ -131,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): + def test_complete_operation_unknown_session(self) -> None: """ Attempting to mark an invalid session as complete should error. """ @@ -164,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns @@ -181,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase): return config - def create_resource_dict(self): + def create_resource_dict(self) -> Dict[str, Resource]: resource_dict = super().create_resource_dict() resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.device_id = "dev1" @@ -228,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase): return channel - def test_ui_auth(self): + def test_ui_auth(self) -> None: """ Test user interactive authentication outside of registration. """ @@ -258,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_grandfathered_identifier(self): + def test_grandfathered_identifier(self) -> None: """Check behaviour without "identifier" dict Synapse used to require clients to submit a "user" field for m.login.password @@ -285,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_can_change_body(self): + def test_can_change_body(self) -> None: """ The client dict can be modified during the user interactive authentication session. @@ -324,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_cannot_change_uri(self): + def test_cannot_change_uri(self) -> None: """ The initial requested URI cannot be modified during the user interactive authentication session. """ @@ -361,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase): ) @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): + def test_can_reuse_session(self) -> None: """ The session can be reused if configured. @@ -408,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): + def test_ui_auth_via_sso(self) -> None: """Test a successful UI Auth flow via SSO This includes: @@ -451,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): + def test_does_not_offer_password_for_sso_user(self) -> None: login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -463,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - def test_does_not_offer_sso_for_password_user(self): + def test_does_not_offer_sso_for_password_user(self) -> None: channel = self.delete_device( self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED ) @@ -473,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): + def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -490,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): + def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) @@ -527,12 +532,13 @@ class RefreshAuthTests(unittest.HomeserverTestCase): auth.register_servlets, account.register_servlets, login.register_servlets, + logout.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) @@ -546,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token) -> bool: + def is_access_token_valid(self, access_token: str) -> bool: """ Checks whether an access token is valid, returning whether it is or not. """ @@ -559,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): return code == HTTPStatus.OK - def test_login_issue_refresh_token(self): + def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. """ @@ -589,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) - def test_register_issue_refresh_token(self): + def test_register_issue_refresh_token(self) -> None: """ A register response should include a refresh_token only if asked. """ @@ -625,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) - def test_token_refresh(self): + def test_token_refresh(self) -> None: """ A refresh token can be used to issue a new access token. """ @@ -663,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refreshable_access_token_expiration(self) -> None: """ The access token should have some time as specified in the config. """ @@ -720,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "nonrefreshable_access_token_lifetime": "10m", } ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens( + self, + ) -> None: """ Tests that the expiry times for refreshable and non-refreshable access tokens can be different. @@ -780,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) - def test_refresh_token_expiry(self): + def test_refresh_token_expiry(self) -> None: """ The refresh token can be configured to have a limited lifetime. When that lifetime has ended, the refresh token can no longer be used to @@ -832,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "session_lifetime": "3m", } ) - def test_ultimate_session_expiry(self): + def test_ultimate_session_expiry(self) -> None: """ The session can be configured to have an ultimate, limited lifetime. """ @@ -880,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result ) - def test_refresh_token_invalidation(self): + def test_refresh_token_invalidation(self) -> None: """Refresh tokens are invalidated after first use of the next token. A refresh token is considered invalid if: @@ -984,3 +992,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertEqual( fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) + + def test_many_token_refresh(self) -> None: + """ + If a refresh is performed many times during a session, there shouldn't be + extra 'cruft' built up over time. + + This test was written specifically to troubleshoot a case where logout + was very slow if a lot of refreshes had been performed for the session. + """ + + def _refresh(refresh_token: str) -> Tuple[str, str]: + """ + Performs one refresh, returning the next refresh token and access token. + """ + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual( + refresh_response.code, HTTPStatus.OK, refresh_response.result + ) + return ( + refresh_response.json_body["refresh_token"], + refresh_response.json_body["access_token"], + ) + + def _table_length(table_name: str) -> int: + """ + Helper to get the size of a table, in rows. + For testing only; trivially vulnerable to SQL injection. + """ + + def _txn(txn: LoggingTransaction) -> int: + txn.execute(f"SELECT COUNT(1) FROM {table_name}") + row = txn.fetchone() + # Query is infallible + assert row is not None + return row[0] + + return self.get_success( + self.hs.get_datastores().main.db_pool.runInteraction( + "_table_length", _txn + ) + ) + + # Before we log in, there are no access tokens. + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + + access_token = login_response.json_body["access_token"] + refresh_token = login_response.json_body["refresh_token"] + + # Now that we have logged in, there should be one access token and one + # refresh token + self.assertEqual(_table_length("access_tokens"), 1) + self.assertEqual(_table_length("refresh_tokens"), 1) + + for _ in range(5): + refresh_token, access_token = _refresh(refresh_token) + + # After 5 sequential refreshes, there should only be the latest two + # refresh/access token pairs. + # (The last one is preserved because it's in use! + # The one before that is preserved because it can still be used to + # replace the last token pair, in case of e.g. a network interruption.) + self.assertEqual(_table_length("access_tokens"), 2) + self.assertEqual(_table_length("refresh_tokens"), 2) + + logout_response = self.make_request( + "POST", "/_matrix/client/v3/logout", {}, access_token=access_token + ) + self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result) + + # Now that we have logged in, there should be no access token + # and no refresh token + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 989e80176..d1751e155 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -13,9 +13,13 @@ # limitations under the License. from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.localpart = "user" self.password = "pass" self.user = self.register_user(self.localpart, self.password) - def test_check_auth_required(self): + def test_check_auth_required(self) -> None: channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) - def test_get_room_version_capabilities(self): + def test_get_room_version_capabilities(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities_password_login(self): + def test_get_change_password_capabilities_password_login(self) -> None: access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) @@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): + def test_get_change_password_capabilities_localdb_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.change_password"]["enabled"]) @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): + def test_get_change_password_capabilities_password_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities(self): + def test_get_change_users_attributes_capabilities(self) -> None: """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) @@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) @override_config({"enable_set_displayname": False}) - def test_get_set_displayname_capabilities_displayname_disabled(self): + def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_displayname"]["enabled"]) @override_config({"enable_set_avatar_url": False}) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) @override_config({"enable_3pid_changes": False}) - def test_get_change_3pid_capabilities_3pid_disabled(self): + def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) - def test_get_does_not_include_msc3244_fields_when_disabled(self): + def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None @@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - def test_get_does_include_msc3244_fields_when_enabled(self): + def test_get_does_include_msc3244_fields_when_enabled(self) -> None: access_token = self.get_success( self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index fcdc56581..b1ca81a91 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -13,11 +13,16 @@ # limitations under the License. import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder from synapse.rest.client import login, room from synapse.rest.consent import consent_resource +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["form_secret"] = "123abc" @@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_render_public_consent(self): + def test_render_public_consent(self) -> None: """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( @@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): "/consent?v=1", shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) - def test_accept_consent(self): + def test_accept_consent(self) -> None: """ A user can use the consent form to accept the terms. """ @@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and whether we've consented version, consented = channel.result["body"].decode("ascii").split(",") @@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Fetch the consent page, to get the consent version -- it should have # changed @@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): access_token=access_token, shorthand=False, ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py new file mode 100644 index 000000000..a8af4e243 --- /dev/null +++ b/tests/rest/client/test_device_lists.py @@ -0,0 +1,159 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus + +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 3d7aa8ec8..9fa1f82df 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -11,9 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_ephemeral_messages"] = True @@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) - def test_message_expiry_no_delay(self): + def test_message_expiry_no_delay(self) -> None: """Tests that sending a message sent with a m.self_destruct_after field set to the past results in that event being deleted right away. """ @@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def test_message_expiry_delay(self): + def test_message_expiry_delay(self) -> None: """Tests that sending a message with a m.self_destruct_after field set to the future results in that event not being deleted right away, but advancing the clock to after that expiry timestamp causes the event to be deleted. @@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): event_content = self.get_event(self.room_id, event_id)["content"] self.assertFalse(bool(event_content), event_content) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a90294003..1b1392fa2 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -16,8 +16,12 @@ from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_registration_captcha"] = False @@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() + hs.get_federation_handler = Mock() # type: ignore[assignment] return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): self.other_user = self.register_user("other2", "pass") self.other_token = self.login(self.other_user, "pass") - def test_stream_basic_permissions(self): + def test_stream_basic_permissions(self) -> None: # invalid token, expect 401 # note: this is in violation of the original v1 spec, which expected # 403. However, since the v1 spec no longer exists and the v1 @@ -65,18 +69,18 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.assertEquals(channel.code, 401, msg=channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) # valid token, expect content channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_stream_room_permissions(self): + def test_stream_room_permissions(self) -> None: room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) self.helper.send(room_id, tok=self.other_token) @@ -89,10 +93,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down - self.assertEquals( + self.assertEqual( 0, len( [ @@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # left to room (expect no content for room) - def TODO_test_stream_items(self): + def TODO_test_stream_items(self) -> None: # new user, no content # join room, expect 1 item (join) @@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, hs, reactor, clock): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register an account self.user_id = self.register_user("sid1", "pass") @@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - def test_get_event_via_events(self): + def test_get_event_via_events(self) -> None: resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] @@ -153,4 +157,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase): "/events/" + event_id, access_token=self.token, ) - self.assertEquals(channel.code, 200, msg=channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 475c6bed3..5c31a5442 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_add_filter(self): channel = self.make_request( @@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) + self.assertEqual(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): channel = self.make_request( @@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine @@ -68,7 +68,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): filter_id = defer.ensureDeferred( @@ -83,7 +83,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): channel = self.make_request( @@ -91,7 +91,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py index ad0425ae6..e067cf825 100644 --- a/tests/rest/client/test_groups.py +++ b/tests/rest/client/test_groups.py @@ -25,13 +25,13 @@ class GroupsTestCase(unittest.HomeserverTestCase): servlets = [room.register_servlets, groups.register_servlets] @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self): + def test_rooms_limited_by_visibility(self) -> None: group_id = "+spqr:test" # Alice creates a group channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {"group_id": group_id}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {"group_id": group_id}) # Bob creates a private room room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) @@ -45,12 +45,12 @@ class GroupsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} ) - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals(channel.json_body, {}) + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual(channel.json_body, {}) # Alice now tries to retrieve the room list of the space. channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEquals(channel.code, 200, msg=channel.text_body) - self.assertEquals( + self.assertEqual(channel.code, 200, msg=channel.text_body) + self.assertEqual( channel.json_body, {"chunk": [], "total_room_count_estimate": 0} ) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index becb4e8dc..299b9d21e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -13,9 +13,14 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["enable_3pid_lookup"] = False @@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs - def test_3pid_lookup_disabled(self): + def test_3pid_lookup_disabled(self) -> None: self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] params = { @@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index d7fa635ea..bbc8e7424 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def test_rejects_device_id_ice_key_outside_of_list(self): + def test_rejects_device_id_ice_key_outside_of_list(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_rejects_device_key_given_as_map_to_bool(self): + def test_rejects_device_key_given_as_map_to_bool(self) -> None: self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") bob = self.register_user("bob", "uncle") @@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_requires_device_key(self): + def test_requires_device_key(self) -> None: """`device_keys` is required. We should complain if it's missing.""" self.register_user("alice", "wonderland") alice_token = self.login("alice", "wonderland") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 19f5e4653..090d2d0a2 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,6 +20,7 @@ from urllib.parse import urlencode import pymacaroons +from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin @@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.hs.config.registration.enable_registration = True self.hs.config.registration.registrations_require_3pid = [] @@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_address(self): + def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -132,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -150,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account(self): + def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -177,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -195,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config( { @@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } } ) - def test_POST_ratelimiting_per_account_failed_attempts(self): + def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") for i in range(0, 6): @@ -222,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -240,16 +244,16 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): + def test_soft_logout(self) -> None: self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal params = { @@ -259,22 +263,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # # test behaviour after deleting the expired device @@ -286,24 +290,26 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], False) - def _delete_device(self, access_token, user_id, password, device_id): + def _delete_device( + self, access_token: str, user_id: str, password: str, device_id: str + ) -> None: """Perform the UI-Auth to delete a device""" channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEquals(channel.code, 401, channel.result) + self.assertEqual(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -326,10 +332,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): + def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -337,23 +343,25 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( + self, + ) -> None: self.register_user("kermit", "monkey") # log in as normal @@ -361,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) + self.assertEqual(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) + self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEqual(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") @@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_get_login_flows(self): + def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) @@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ], ) - def test_multi_sso_redirect(self): + def test_multi_sso_redirect(self) -> None: """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) @@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - def test_multi_sso_redirect_to_cas(self): + def test_multi_sso_redirect_to_cas(self) -> None: """If CAS is chosen, should redirect to the CAS server""" channel = self.make_request( @@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri_params = urllib.parse.parse_qs(service_uri_query) self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_saml(self): + def test_multi_sso_redirect_to_saml(self) -> None: """If SAML is chosen, should redirect to the SAML server""" channel = self.make_request( "GET", @@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): relay_state_param = saml_uri_params["RelayState"][0] self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_login_via_oidc(self): + def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" # pick the default OIDC provider @@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") - def test_multi_sso_redirect_to_unknown(self): + def test_multi_sso_redirect_to_unknown(self) -> None: """An unknown IdP should cause a 400""" channel = self.make_request( "GET", @@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_to_unknown(self): + def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - def test_client_idp_redirect_to_oidc(self): + def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): + def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.base_url = "https://matrix.goodserver.com/" self.redirect_path = "_synapse/client/login/sso/redirect/confirm" @@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_user_id = "username" self.user_id = "@%s:test" % cas_user_id - async def get_raw(uri, args): + async def get_raw(uri: str, args: Any) -> bytes: """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 @@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.deactivate_account_handler = hs.get_deactivate_account_handler() - def test_cas_redirect_confirm(self): + def test_cas_redirect_confirm(self) -> None: """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ @@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase): } } ) - def test_cas_redirect_whitelisted(self): + def test_cas_redirect_whitelisted(self) -> None: """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): + def test_cas_redirect_login_fallback(self) -> None: self._test_redirect("https://example.com/_matrix/static/client/login") - def _test_redirect(self, redirect_url): + def _test_redirect(self, redirect_url: str) -> None: """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" @@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: """Logging in as a deactivated account should error.""" redirect_url = "https://legit-site.com/" @@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "algorithm": jwt_algorithm, } - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # If jwt_config has been defined (eg via @override_config), don't replace it. @@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid_registered(self): + def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_valid_unregistered(self): + def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: Signature verification failed", ) - def test_login_jwt_expired(self): + def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature has expired" ) - def test_login_jwt_not_before(self): + def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase): "JWT validation failed: The token is not yet valid (nbf)", ) - def test_login_no_sub(self): + def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) - def test_login_iss(self): + def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) @@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "iss" claim', ) - def test_login_iss_no_config(self): + def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) - def test_login_aud(self): + def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) @@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase): 'JWT validation failed: Token is missing the "aud" claim', ) - def test_login_aud_no_config(self): + def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) - def test_login_default_sub(self): + def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) - def test_login_custom_sub(self): + def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") - def test_login_no_token(self): + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["jwt_config"] = { "enabled": True, @@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return result.decode("ascii") return result - def jwt_login(self, *args): + def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel - def test_login_jwt_valid(self): + def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - def test_login_jwt_invalid_signature(self): + def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() self.service = ApplicationService( @@ -1101,11 +1113,11 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): }, ) - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) + self.hs.get_datastores().main.services_cache.append(self.service) + self.hs.get_datastores().main.services_cache.append(self.another_service) return self.hs - def test_login_appservice_user(self): + def test_login_appservice_user(self) -> None: """Test that an appservice user can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1117,9 +1129,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_login_appservice_user_bot(self): + def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" self.register_appservice_user(AS_USER, self.service.token) @@ -1131,9 +1143,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_login_appservice_wrong_user(self): + def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1145,9 +1157,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) - def test_login_appservice_wrong_as(self): + def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" self.register_appservice_user(AS_USER, self.service.token) @@ -1159,9 +1171,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) - def test_login_appservice_no_token(self): + def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice login method """ @@ -1173,7 +1185,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) @skip_unless(HAS_OIDC, "requires OIDC") @@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): servlets = [login.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self): + def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" # do the start of the login flow diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3cf587189..3a74d2e96 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -13,11 +13,16 @@ # limitations under the License. import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import account, login, password_policy, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.register_url = "/_matrix/client/r0/register" self.policy = { "enabled": True, @@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) return hs - def test_get_policy(self): + def test_get_policy(self) -> None: """Tests if the /password_policy endpoint returns the configured policy.""" channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual( channel.json_body, { @@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): channel.result, ) - def test_password_too_short(self): + def test_password_too_short(self) -> None: request_data = json.dumps({"username": "kermit", "password": "shorty"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, ) - def test_password_no_digit(self): + def test_password_no_digit(self) -> None: request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, ) - def test_password_no_symbol(self): + def test_password_no_symbol(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, ) - def test_password_no_uppercase(self): + def test_password_no_uppercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, ) - def test_password_no_lowercase(self): + def test_password_no_lowercase(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) channel = self.make_request("POST", self.register_url, request_data) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, ) - def test_password_compliant(self): + def test_password_compliant(self) -> None: request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) - def test_password_change(self): + def test_password_change(self) -> None: """This doesn't test every possible use case, only that hitting /account/password triggers the password validation code. """ @@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): access_token=tok, ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index c0de4c93a..27dcfc83d 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a room admin, moderator and regular user self.admin_user_id = self.register_user("admin", "pass") self.admin_access_token = self.login("admin", "pass") @@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, ) - def test_non_admins_cannot_enable_room_encryption(self): + def test_non_admins_cannot_enable_room_encryption(self) -> None: # have the mod try to enable room encryption self.helper.send_state( self.room_id, @@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_send_server_acl(self): + def test_non_admins_cannot_send_server_acl(self) -> None: # have the mod try to send a server ACL self.helper.send_state( self.room_id, @@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a server ACL @@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.user_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) - def test_non_admins_cannot_tombstone_room(self): + def test_non_admins_cannot_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.mod_access_token, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure ) # have the user try to send a tombstone event @@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase): expect_code=403, # expect failure ) - def test_admins_can_enable_room_encryption(self): + def test_admins_can_enable_room_encryption(self) -> None: # have the admin try to enable room encryption self.helper.send_state( self.room_id, "m.room.encryption", {"algorithm": "m.megolm.v1.aes-sha2"}, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_send_server_acl(self): + def test_admins_can_send_server_acl(self) -> None: # have the admin try to send a server ACL self.helper.send_state( self.room_id, @@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "deny": ["*.evil.com", "evil.com"], }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_admins_can_tombstone_room(self): + def test_admins_can_tombstone_room(self) -> None: # Create another room that will serve as our "upgraded room" self.upgraded_room_id = self.helper.create_room_as( self.admin_user_id, tok=self.admin_access_token @@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase): "replacement_room": self.upgraded_room_id, }, tok=self.admin_access_token, - expect_code=200, # expect success + expect_code=HTTPStatus.OK, # expect success ) - def test_cannot_set_string_power_levels(self): + def test_cannot_set_string_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_large_power_levels(self): + def test_cannot_set_unsafe_large_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( @@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase): body, ) - def test_cannot_set_unsafe_small_power_levels(self): + def test_cannot_set_unsafe_small_power_levels(self) -> None: room_power_levels = self.helper.get_state( self.room_id, "m.room.power_levels", @@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase): "m.room.power_levels", room_power_levels, tok=self.admin_access_token, - expect_code=400, # expect failure + expect_code=HTTPStatus.BAD_REQUEST, # expect failure ) self.assertEqual( diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 56fe1a3d0..0abe378fe 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): user = UserID.from_string(user_id) servlets = [presence.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) @@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): return hs - def test_put_presence(self): + def test_put_presence(self) -> None: """ PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. @@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): + def test_put_presence_disabled(self) -> None: """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. @@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase): "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index ead883ded..77c3ced42 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,12 +13,16 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from typing import Any, Dict +from typing import Any, Dict, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest @@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") self.other = self.register_user("other", "pass", displayname="Bob") - def test_get_displayname(self): + def test_get_displayname(self) -> None: res = self._get_displayname() self.assertEqual(res, "owner") - def test_set_displayname(self): + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -57,7 +61,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "test") - def test_set_displayname_noauth(self): + def test_set_displayname_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), @@ -65,7 +69,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_displayname_too_long(self): + def test_set_displayname_too_long(self) -> None: """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", @@ -78,11 +82,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") - def test_get_displayname_other(self): + def test_get_displayname_other(self) -> None: res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") + self.assertEqual(res, "Bob") - def test_set_displayname_other(self): + def test_set_displayname_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.other,), @@ -91,11 +95,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_get_avatar_url(self): + def test_get_avatar_url(self) -> None: res = self._get_avatar_url() self.assertIsNone(res) - def test_set_avatar_url(self): + def test_set_avatar_url(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -107,7 +111,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertEqual(res, "http://my.server/pic.gif") - def test_set_avatar_url_noauth(self): + def test_set_avatar_url_noauth(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.owner,), @@ -115,7 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 401, channel.result) - def test_set_avatar_url_too_long(self): + def test_set_avatar_url_too_long(self) -> None: """Attempts to set a stupid avatar_url should get a 400""" channel = self.make_request( "PUT", @@ -128,11 +132,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_avatar_url() self.assertIsNone(res) - def test_get_avatar_url_other(self): + def test_get_avatar_url_other(self) -> None: res = self._get_avatar_url(self.other) self.assertIsNone(res) - def test_set_avatar_url_other(self): + def test_set_avatar_url_other(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.other,), @@ -141,14 +145,14 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name=None): + def _get_displayname(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] - def _get_avatar_url(self, name=None): + def _get_avatar_url(self, name: Optional[str] = None) -> str: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) @@ -156,7 +160,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_global(self): + def test_avatar_size_limit_global(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a global profile. """ @@ -187,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_size_limit_per_room(self): + def test_avatar_size_limit_per_room(self) -> None: """Tests that the maximum size limit for avatars is enforced when updating a per-room profile. """ @@ -220,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_global(self): + def test_avatar_allowed_mime_type_global(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a global profile. """ @@ -251,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_allowed_mime_type_per_room(self): + def test_avatar_allowed_mime_type_per_room(self) -> None: """Tests that the MIME type whitelist for avatars is enforced when updating a per-room profile. """ @@ -283,7 +287,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: """Stores metadata about files in the database. Args: @@ -292,7 +296,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): properties are "mimetype" (for the file's type) and "size" (for the file's size). """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main for name, props in names_and_props.items(): self.get_success( @@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -325,7 +328,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User owning the requested profile. self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") @@ -337,22 +340,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - def test_no_auth(self): + def test_no_auth(self) -> None: self.try_fetch_profile(401) - def test_not_in_shared_room(self): + def test_not_in_shared_room(self) -> None: self.ensure_requester_left_room() self.try_fetch_profile(403, access_token=self.requester_tok) - def test_in_shared_room(self): + def test_in_shared_room(self) -> None: self.ensure_requester_left_room() self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) - def try_fetch_profile(self, expected_code, access_token=None): + def try_fetch_profile( + self, expected_code: int, access_token: Optional[str] = None + ) -> None: self.request_profile(expected_code, access_token=access_token) self.request_profile( @@ -363,13 +368,18 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): expected_code, url_suffix="/avatar_url", access_token=access_token ) - def request_profile(self, expected_code, url_suffix="", access_token=None): + def request_profile( + self, + expected_code: int, + url_suffix: str = "", + access_token: Optional[str] = None, + ) -> None: channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) self.assertEqual(channel.code, expected_code, channel.result) - def ensure_requester_left_room(self): + def ensure_requester_left_room(self) -> None: try: self.helper.leave( room=self.room_id, user=self.requester, tok=self.requester_tok @@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_auth_for_profile_requests"] = True config["limit_profile_requests_to_users_who_share_rooms"] = True @@ -397,12 +407,12 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # User requesting the profile. self.requester = self.register_user("requester", "pass") self.requester_tok = self.login("requester", "pass") - def test_can_lookup_own_profile(self): + def test_can_lookup_own_profile(self) -> None: """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index d0ce91ccd..4f875b928 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): ] hijack_auth = False - def test_enabled_on_creation(self): + def test_enabled_on_creation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even though a rule with that @@ -56,7 +56,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_on_recreation(self): + def test_enabled_on_recreation(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is enabled upon creation, even if a rule with that @@ -113,7 +113,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_disable(self): + def test_enabled_disable(self) -> None: """ Tests the GET and PUT of push rules' `enabled` endpoints. Tests that a rule is disabled and enabled when we ask for it. @@ -166,7 +166,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) - def test_enabled_404_when_get_non_existent(self): + def test_enabled_404_when_get_non_existent(self) -> None: """ Tests that `enabled` gives 404 when the rule doesn't exist. """ @@ -212,7 +212,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_get_non_existent_server_rule(self): + def test_enabled_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when the server-default rule doesn't exist. """ @@ -226,7 +226,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_rule(self): + def test_enabled_404_when_put_non_existent_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a rule that doesn't exist. """ @@ -243,7 +243,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_enabled_404_when_put_non_existent_server_rule(self): + def test_enabled_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. """ @@ -260,7 +260,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_get(self): + def test_actions_get(self) -> None: """ Tests that `actions` gives you what you expect on a fresh rule. """ @@ -289,7 +289,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] ) - def test_actions_put(self): + def test_actions_put(self) -> None: """ Tests that PUT on actions updates the value you'd get from GET. """ @@ -325,7 +325,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - def test_actions_404_when_get_non_existent(self): + def test_actions_404_when_get_non_existent(self) -> None: """ Tests that `actions` gives 404 when the rule doesn't exist. """ @@ -365,7 +365,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_get_non_existent_server_rule(self): + def test_actions_404_when_get_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when the server-default rule doesn't exist. """ @@ -379,7 +379,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_rule(self): + def test_actions_404_when_put_non_existent_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a rule that doesn't exist. """ @@ -396,7 +396,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_actions_404_when_put_non_existent_server_rule(self): + def test_actions_404_when_put_non_existent_server_rule(self) -> None: """ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. """ diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 433d715f6..7401b5e0c 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,9 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["rc_message"] = {"per_second": 0.2, "burst_count": 10} @@ -36,7 +42,7 @@ class RedactionsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # register a couple of users self.mod_user_id = self.register_user("user1", "pass") self.mod_access_token = self.login("user1", "pass") @@ -60,7 +66,9 @@ class RedactionsTestCase(HomeserverTestCase): room=self.room_id, user=self.other_user_id, tok=self.other_access_token ) - def _redact_event(self, access_token, room_id, event_id, expect_code=200): + def _redact_event( + self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + ) -> JsonDict: """Helper function to send a redaction event. Returns the json body. @@ -71,13 +79,13 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body - def _sync_room_timeline(self, access_token, room_id): + def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] - def test_redact_event_as_moderator(self): + def test_redact_event_as_moderator(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -98,7 +106,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_event_as_normal(self): + def test_redact_event_as_normal(self) -> None: # as a regular user, send a message to redact b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) normal_msg_id = b["event_id"] @@ -133,7 +141,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-3]["content"], {}) - def test_redact_nonexistent_event(self): + def test_redact_nonexistent_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) msg_id = b["event_id"] @@ -158,7 +166,7 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) - def test_redact_create_event(self): + def test_redact_create_event(self) -> None: # control case: an existing event b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) msg_id = b["event_id"] @@ -178,7 +186,7 @@ class RedactionsTestCase(HomeserverTestCase): self.other_access_token, self.room_id, create_event_id, expect_code=403 ) - def test_redact_event_as_moderator_ratelimit(self): + def test_redact_event_as_moderator_ratelimit(self) -> None: """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 0f1c47dcb..9aebf1735 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -16,15 +16,21 @@ import datetime import json import os +from typing import Any, Dict, List, Tuple import pkg_resources +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.server import HomeServer from synapse.storage._base import db_to_json +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ] url = b"/_matrix/client/r0/register" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_POST_appservice_registration_valid(self): + def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" @@ -56,7 +62,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps( {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} ) @@ -65,11 +71,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_appservice_registration_no_type(self): + def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" appservice = ApplicationService( @@ -80,16 +86,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): sender="@as:test", ) - self.hs.get_datastore().services_cache.append(appservice) + self.hs.get_datastores().main.services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.result["code"], b"400", channel.result) - def test_POST_appservice_registration_invalid(self): + def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists request_data = json.dumps( {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} @@ -98,23 +104,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) - def test_POST_bad_password(self): + def test_POST_bad_password(self) -> None: request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): + def test_POST_bad_username(self) -> None: request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") + self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): + def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" params = { @@ -131,58 +137,58 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): + def test_POST_disabled_registration(self) -> None: request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Registration has been disabled") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - def test_POST_guest_registration(self): + def test_POST_guest_registration(self) -> None: self.hs.config.key.macaroon_secret_key = "test" self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - def test_POST_disabled_guest_registration(self): + def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): + def test_POST_ratelimiting_guest(self) -> None: for i in range(0, 6): url = self.url + b"?kind=guest" channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): + def test_POST_ratelimiting(self) -> None: for i in range(0, 6): params = { "username": "kermit" + str(i), @@ -194,23 +200,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) @override_config({"registration_requires_token": True}) - def test_POST_registration_requires_token(self): + def test_POST_registration_requires_token(self) -> None: username = "kermit" device_id = "frogfone" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = { + params: JsonDict = { "username": username, "password": "monkey", "device_id": device_id, @@ -231,7 +237,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -249,7 +255,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -265,7 +271,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -276,12 +282,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["completed"], 1) - self.assertEquals(res["pending"], 0) + self.assertEqual(res["completed"], 1) + self.assertEqual(res["pending"], 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_invalid(self): - params = { + def test_POST_registration_token_invalid(self) -> None: + params: JsonDict = { "username": "kermit", "password": "monkey", } @@ -295,28 +301,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_limit_uses(self): + def test_POST_registration_token_limit_uses(self) -> None: token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that can be used once self.get_success( store.db_pool.simple_insert( @@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] @@ -354,7 +360,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 1) + self.assertEqual(pending, 1) # Check auth fails when using token with session2 params2["auth"] = { @@ -363,9 +369,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session2, } channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY @@ -378,20 +384,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcols=["pending", "completed"], ) ) - self.assertEquals(res["pending"], 0) - self.assertEquals(res["completed"], 1) + self.assertEqual(res["pending"], 0) + self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_expiry(self): + def test_POST_registration_token_expiry(self) -> None: token = "abcd" now = self.hs.get_clock().time_msec() - store = self.hs.get_datastore() + store = self.hs.get_datastores().main # Create token that expired yesterday self.get_success( store.db_pool.simple_insert( @@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, ) ) - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -417,9 +423,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) - self.assertEquals(channel.json_body["completed"], []) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(channel.json_body["completed"], []) # Update token so it expires tomorrow self.get_success( @@ -436,10 +442,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry(self): + def test_POST_registration_token_session_expiry(self) -> None: """Test `pending` is decremented when an uncompleted session expires.""" token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do 2 requests without auth to get two session IDs - params1 = {"username": "bert", "password": "monkey"} - params2 = {"username": "ernie", "password": "monkey"} + params1: JsonDict = {"username": "bert", "password": "monkey"} + params2: JsonDict = {"username": "ernie", "password": "monkey"} channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) session1 = channel1.json_body["session"] channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) @@ -504,7 +510,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="result", ) ) - self.assertEquals(db_to_json(result2), token) + self.assertEqual(db_to_json(result2), token) # Delete both sessions (mimics expiry) self.get_success( @@ -519,10 +525,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): retcol="pending", ) ) - self.assertEquals(pending, 0) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) - def test_POST_registration_token_session_expiry_deleted_token(self): + def test_POST_registration_token_session_expiry_deleted_token(self) -> None: """Test session expiry doesn't break when the token is deleted. 1. Start but don't complete UIA with a registration token @@ -530,7 +536,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): 3. Expire the session """ token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Do request without auth to get a session ID - params = {"username": "kermit", "password": "monkey"} + params: JsonDict = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, json.dumps(params)) session = channel.json_body["session"] @@ -570,9 +576,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) ) - def test_advertised_flows(self): + def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -593,9 +599,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): + def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -625,9 +631,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_advertised_flows_no_msisdn_email_required(self): + def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_request_token_existing_email_inhibit_error(self): + def test_request_token_existing_email_inhibit_error(self) -> None: """Test that requesting a token via this endpoint doesn't leak existing associations if configured that way. """ @@ -657,7 +663,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Add a threepid self.get_success( - self.hs.get_datastore().user_add_threepid( + self.hs.get_datastores().main.user_add_threepid( user_id=user_id, medium="email", address=email, @@ -671,7 +677,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): }, } ) - def test_reject_invalid_email(self): + def test_reject_invalid_email(self) -> None: """Check that bad emails are rejected""" # Test for email with multiple @ @@ -694,9 +700,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) + self.assertEqual(400, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -707,8 +713,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -720,8 +726,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( + self.assertEqual(400, channel.code, channel.result) + self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) @@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "inhibit_user_in_use_error": True, } ) - def test_inhibit_user_in_use_error(self): + def test_inhibit_user_in_use_error(self) -> None: """Tests that the 'inhibit_user_in_use_error' configuration flag behaves correctly. """ @@ -745,7 +751,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): account_validity.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week. config["enable_registration"] = True @@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): return self.hs - def test_validity_period(self): + def test_validity_period(self) -> None: self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -799,18 +805,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_manual_renewal(self): + def test_manual_renewal(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -826,14 +832,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): params = {"user_id": user_id} request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_manual_expire(self): + def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -848,17 +854,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) - def test_logging_out_expired_user(self): + def test_logging_out_expired_user(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -873,18 +879,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Test for account expiring after a week and renewal emails being sent 2 @@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) - async def sendmail(*args, **kwargs): + async def sendmail(*args: Any, **kwargs: Any) -> None: self.email_attempts.append((args, kwargs)) - self.email_attempts = [] + self.email_attempts: List[Tuple[Any, Any]] = [] self.hs.get_send_email_handler()._sendmail = sendmail - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs - def test_renewal_email(self): + def test_renewal_email(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -959,7 +965,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -977,7 +983,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -997,14 +1003,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) - def test_renewal_invalid_token(self): + def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): channel.result["body"], expected_html.encode("utf8"), channel.result ) - def test_manual_email_send(self): + def test_manual_email_send(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1028,11 +1034,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) - def test_deactivated_user(self): + def test_deactivated_user(self) -> None: self.email_attempts = [] (user_id, tok) = self.create_user() @@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEqual(len(self.email_attempts), 0) - def create_user(self): + def create_user(self) -> Tuple[str, str]: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do @@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): ) return user_id, tok - def test_manual_email_send_expired_account(self): + def test_manual_email_send_expired_account(self) -> None: user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") @@ -1103,7 +1109,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.validity_period = 10 self.max_delta = self.validity_period * 10.0 / 100.0 @@ -1126,14 +1132,16 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): # We need to set these directly, instead of in the homeserver config dict above. # This is due to account validity-related config options not being read by # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + self.hs.get_datastores().main._account_validity_period = self.validity_period + self.hs.get_datastores().main._account_validity_startup_job_max_delta = ( + self.max_delta + ) - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main return self.hs - def test_background_job(self): + def test_background_job(self) -> None: """ Tests the same thing as test_background_job, except that it sets the startup_job_max_delta parameter and checks that the expiration date is within the @@ -1156,14 +1164,14 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = "/_matrix/client/v1/register/m.login.registration_token/validity" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["registration_requires_token"] = True return config - def test_GET_token_valid(self): + def test_GET_token_valid(self) -> None: token = "abcd" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "registration_tokens", @@ -1181,22 +1189,22 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], True) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], True) - def test_GET_token_invalid(self): + def test_GET_token_invalid(self) -> None: token = "1234" channel = self.make_request( b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertEquals(channel.json_body["valid"], False) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["valid"], False) @override_config( {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} ) - def test_GET_ratelimiting(self): + def test_GET_ratelimiting(self) -> None: token = "1234" for i in range(0, 6): @@ -1206,10 +1214,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1217,4 +1225,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index de80aca03..c8db45719 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -18,11 +18,15 @@ import urllib.parse from typing import Dict, List, Optional, Tuple from unittest.mock import patch +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.server import HomeServer from synapse.storage.relations import RelationPaginationToken from synapse.types import JsonDict, StreamToken +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -52,8 +56,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -63,13 +67,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] - def test_send_relation(self): + def test_send_relation(self) -> None: """Tests that sending a relation using the new /send_relation works creates the right shape of event. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -78,7 +82,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assert_dict( { @@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_invalid_event(self): + def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" channel = self._send_relation( RelationTypes.ANNOTATION, @@ -103,11 +107,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_relations", values={ "event_id": "bar", @@ -123,9 +127,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - def test_deny_invalid_room(self): + def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -136,17 +140,17 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_deny_double_react(self): + def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_deny_forked_thread(self): + def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" channel = self._send_relation( RelationTypes.THREAD, @@ -154,7 +158,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] channel = self._send_relation( @@ -163,16 +167,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) - def test_basic_paginate_relations(self): + def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) second_annotation_id = channel.json_body["event_id"] channel = self.make_request( @@ -180,11 +184,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the latest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": second_annotation_id, @@ -195,7 +199,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -212,11 +216,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the earliest # full relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( { "event_id": first_annotation_id, @@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ).to_string(self.store) ) - def test_repeated_paginate_relations(self): + def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. """ @@ -245,7 +249,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) prev_token = "" @@ -260,12 +264,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -273,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -288,12 +292,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -301,12 +305,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) - def test_pagination_from_sync_and_messages(self): + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Send an event after the relation events. self.helper.send(self.room, body="Latest event", tok=self.user_token) @@ -319,7 +323,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] sync_prev_batch = room_timeline["prev_batch"] self.assertIsNotNone(sync_prev_batch) @@ -335,7 +339,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) messages_end = channel.json_body["end"] self.assertIsNotNone(messages_end) # Ensure the relation event is not in the chunk returned from /messages. @@ -355,14 +359,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The relation should be in the returned chunk. self.assertIn( annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self): + def test_aggregation_pagination_groups(self) -> None: """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. @@ -386,7 +390,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): key=key, access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) idx += 1 idx %= len(access_tokens) @@ -404,7 +408,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -419,15 +423,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: break - self.assertEquals(sent_groups, found_groups) + self.assertEqual(sent_groups, found_groups) - def test_aggregation_pagination_within_group(self): + def test_aggregation_pagination_within_group(self) -> None: """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. @@ -449,14 +453,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): key="👍", access_token=access_tokens[idx], ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) idx += 1 # Also send a different type of reaction so that we test we don't see it channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) prev_token = "" found_event_ids: List[str] = [] @@ -473,7 +477,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -481,7 +485,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -489,7 +493,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) # Reset and try again, but convert the tokens to the legacy format. prev_token = "" @@ -506,7 +510,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -514,7 +518,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): next_batch = channel.json_body.get("next_batch") - self.assertNotEquals(prev_token, next_batch) + self.assertNotEqual(prev_token, next_batch) prev_token = next_batch if not prev_token: @@ -522,21 +526,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We paginated backwards, so reverse found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) + self.assertEqual(found_event_ids, expected_event_ids) - def test_aggregation(self): + def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -544,9 +548,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, { "chunk": [ @@ -556,17 +560,17 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) - def test_aggregation_redactions(self): + def test_aggregation_redactions(self) -> None: """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions channel = self.make_request( @@ -575,7 +579,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content={}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -583,14 +587,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual( channel.json_body, {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - def test_aggregation_must_be_annotation(self): + def test_aggregation_must_be_annotation(self) -> None: """Test that aggregations must be annotations.""" channel = self.make_request( @@ -599,12 +603,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.assertEquals(400, channel.code, channel.json_body) + self.assertEqual(400, channel.code, channel.json_body) @unittest.override_config( {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} ) - def test_bundled_aggregations(self): + def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -615,29 +619,29 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -655,7 +659,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) # Check the values of each field. - self.assertEquals( + self.assertEqual( { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, @@ -665,12 +669,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): relations_dict[RelationTypes.ANNOTATION], ) - self.assertEquals( + self.assertEqual( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, relations_dict[RelationTypes.REFERENCE], ) - self.assertEquals( + self.assertEqual( 2, relations_dict[RelationTypes.THREAD].get("count"), ) @@ -701,7 +705,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) # Request the room messages. @@ -710,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -719,12 +723,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -737,7 +741,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -746,47 +750,47 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_aggregation_get_event_for_annotation(self): + def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{annotation_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): + def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", f"/rooms/{self.room}/event/{thread_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -801,11 +805,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1) thread_message = channel.json_body["chunk"][0] - self.assertEquals( + self.assertEqual( thread_message["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { @@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_ignore_invalid_room(self): + def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) @@ -905,7 +909,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And when fetching aggregations. @@ -914,7 +918,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) # And for bundled aggregations. @@ -923,11 +927,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{room2}/event/{parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_edit(self): + def test_edit(self) -> None: """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -936,7 +940,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -958,8 +962,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["content"], new_body) assert_bundle(channel.json_body) # Request the room messages. @@ -968,7 +972,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/messages?dir=b", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. @@ -977,7 +981,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) # Request sync, but limit the timeline so it becomes limited (and includes @@ -988,7 +992,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", f"/sync?filter={filter}", access_token=self.user_token ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) @@ -1001,7 +1005,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"search_categories": {"room_events": {"search_term": "Hi"}}}, access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) chunk = [ result["result"] for result in channel.json_body["search_categories"]["room_events"][ @@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_multi_edit(self): + def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ @@ -1024,7 +1028,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -1032,7 +1036,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1045,16 +1049,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) relations_dict = channel.json_body["unsigned"].get("m.relations") self.assertIn(RelationTypes.REPLACE, relations_dict) @@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_reply(self): + def test_edit_reply(self) -> None: """Test that editing a reply works.""" # Create a reply to edit. @@ -1076,7 +1080,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1086,7 +1090,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] @@ -1095,7 +1099,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, reply), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to see the new body in the dict, as well as the reference # metadata sill intact. @@ -1123,7 +1127,47 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_edit(self): + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_edit_thread(self) -> None: + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + + def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} channel = self._send_relation( @@ -1135,7 +1179,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": new_body, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. @@ -1149,7 +1193,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, parent_id=edit_event_id, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1157,9 +1201,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) + self.assertEqual(channel.json_body["content"], new_body) # The relations information should not include the edit to the edit. relations_dict = channel.json_body["unsigned"].get("m.relations") @@ -1173,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_relations_redaction_redacts_edits(self): + def test_relations_redaction_redacts_edits(self) -> None: """Test that edits of an event are redacted when the original event is redacted. """ @@ -1192,7 +1236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned channel = self.make_request( @@ -1201,10 +1245,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) + self.assertEqual(len(channel.json_body["chunk"]), 1) # Redact the original event channel = self.make_request( @@ -1214,7 +1258,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations channel = self.make_request( @@ -1223,13 +1267,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that no relations are returned self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) - def test_aggregations_redaction_prevents_access_to_aggregations(self): + def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None: """Test that annotations of an event are redacted when the original event is redacted. """ @@ -1241,7 +1285,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Redact the original channel = self.make_request( @@ -1255,7 +1299,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Check that aggregations returns zero channel = self.make_request( @@ -1264,15 +1308,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(channel.json_body["chunk"], []) - def test_unknown_relations(self): + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1281,18 +1325,18 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, channel.json_body["chunk"][0], ) # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( + self.assertEqual( channel.json_body["original_event"]["event_id"], self.parent_id ) @@ -1302,7 +1346,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) # But unknown relations can be directly queried. @@ -1312,8 +1356,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ @@ -1377,18 +1421,18 @@ class RelationsTestCase(unittest.HomeserverTestCase): return user_id, access_token - def test_background_update(self): + def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1408,8 +1452,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals( + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good], ) @@ -1433,7 +1477,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) self.assertCountEqual( [ev["event_id"] for ev in channel.json_body["chunk"]], [annotation_event_id_good, thread_event_id], diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index fe5b536d9..f3bf8d093 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -13,9 +13,14 @@ # limitations under the License. from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from synapse.visibility import filter_events_for_client from tests import unittest @@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -47,15 +52,15 @@ class RetentionTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.serializer = self.hs.get_event_client_serializer() self.clock = self.hs.get_clock() - def test_retention_event_purged_with_state_event(self): + def test_retention_event_purged_with_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by a state event. """ @@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 1.5) - def test_retention_event_purged_with_state_event_outside_allowed(self): + def test_retention_event_purged_with_state_event_outside_allowed(self) -> None: """Tests that the server configuration can override the policy for a room when running the purge jobs. """ @@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # instead of the one specified in the room's policy. self._test_retention_event_purged(room_id, one_day_ms * 0.5) - def test_retention_event_purged_without_state_event(self): + def test_retention_event_purged_without_state_event(self) -> None: """Tests that expired events are correctly purged when the room's retention policy is defined by the server's configuration's default retention policy. """ @@ -110,11 +115,11 @@ class RetentionTestCase(unittest.HomeserverTestCase): self._test_retention_event_purged(room_id, one_day_ms * 2) - def test_visibility(self): + def test_visibility(self) -> None: """Tests that synapse.visibility.filter_events_for_client correctly filters out outdated events """ - store = self.hs.get_datastore() + store = self.hs.get_datastores().main storage = self.hs.get_storage() room_id = self.helper.create_room_as(self.user_id, tok=self.token) events = [] @@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id: str, increment: float): + def _test_retention_event_purged(self, room_id: str, increment: float) -> None: """Run the following test scenario to test the message retention policy support: 1. Send event 1 @@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") + assert expired_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(expired_event_id) @@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert valid_event_id is not None # Advance the time again. Now our first event should have expired but our second # one should still be kept. @@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, event_id, expect_none=False): + def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) if expect_none: @@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["retention"] = { "enabled": True, @@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_no_default_policy(self): + def test_no_default_policy(self) -> None: """Tests that an event doesn't get expired if there is neither a default retention policy nor a policy specific to the room. """ @@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id) - def test_state_policy(self): + def test_state_policy(self) -> None: """Tests that an event gets correctly expired if there is no default retention policy but there's a policy specific to the room. """ @@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self._test_retention(room_id, expected_code_for_first_event=404) - def _test_retention(self, room_id, expected_code_for_first_event=200): + def _test_retention( + self, room_id: str, expected_code_for_first_event: int = 200 + ) -> None: # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert first_event_id is not None # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) @@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") + assert second_event_id is not None # Advance the time by another month. self.reactor.advance(one_day_ms * 30 / 1000) @@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): second_event = self.get_event(room_id, second_event_id) self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) - def get_event(self, room_id, event_id, expected_code=200): + def get_event( + self, room_id: str, event_id: str, expected_code: int = 200 + ) -> JsonDict: url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) channel = self.make_request("GET", url, access_token=self.token) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index e9f870403..44f333a0e 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): return room_id, event_id_a, event_id_b, event_id_c @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self): + def test_same_state_groups_for_whole_historical_batch(self) -> None: """Make sure that when using the `/batch_send` endpoint to import a bunch of historical messages, it re-uses the same `state_group` across the whole batch. This is an easy optimization to make sure we're getting diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b7f086927..e0b11e726 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - self.hs.get_datastore().insert_client_ip = _insert_client_ip + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip return self.hs @@ -95,7 +95,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -103,7 +103,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -125,28 +125,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' @@ -156,28 +156,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -185,25 +185,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -211,7 +211,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room=None, members: Iterable = frozenset(), expect_code=None @@ -219,7 +219,7 @@ class RoomPermissionsTestCase(RoomBase): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) + self.assertEqual(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -478,16 +478,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self): """ @@ -498,7 +498,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -507,7 +507,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self): """ @@ -520,14 +520,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self): """ @@ -541,14 +541,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -560,14 +560,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" @@ -576,17 +576,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): @@ -598,19 +598,19 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): @@ -618,16 +618,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -635,7 +635,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEquals(400, channel.code) + self.assertEqual(400, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self): @@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase): # Add the current user to the ratelimit overrides, allowing them no ratelimiting. self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) ) # Test that the invites aren't ratelimited anymore. @@ -694,9 +694,9 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) - self.assertEquals(join_mock.call_count, 0) + self.assertEqual(join_mock.call_count, 0) class RoomTopicTestCase(RoomBase): @@ -712,54 +712,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -775,22 +775,22 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -799,7 +799,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -810,13 +810,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) + self.assertEqual(expected_response, channel.json_body) def test_rooms_members_other(self): self.other_id = "@zzsid1:red" @@ -828,11 +828,11 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" @@ -847,11 +847,11 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(json.loads(content), channel.json_body) class RoomInviteRatelimitTestCase(RoomBase): @@ -937,7 +937,7 @@ class RoomJoinTestCase(RoomBase): False, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -955,7 +955,7 @@ class RoomJoinTestCase(RoomBase): True, ), ) - self.assertEquals( + self.assertEqual( callback_mock.call_args, expected_call_args, callback_mock.call_args, @@ -1013,7 +1013,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1023,10 +1023,10 @@ class RoomJoinRatelimitTestCase(RoomBase): ) channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") + self.assertEqual(channel.json_body["displayname"], "John Doe") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} @@ -1047,7 +1047,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # if all of these requests ended up joining the user to a room. for _ in range(4): channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) @unittest.override_config( { @@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = self.register_user("testuser", "password") # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + rooms = self.get_success( + self.hs.get_datastores().main.get_rooms_for_user(user_id) + ) self.assertEqual(len(rooms), 4) @@ -1076,40 +1078,40 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -1123,10 +1125,10 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self): channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) + self.assertEqual(self.room_id, channel.json_body["room_id"]) + self.assertEqual("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} @@ -1150,7 +1152,7 @@ class RoomInitialSyncTestCase(RoomBase): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): @@ -1166,9 +1168,9 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -1177,14 +1179,14 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) + self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) def test_room_messages_purge(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main pagination_handler = self.hs.get_pagination_handler() # Send a first message in the room, which will be removed by the purge. @@ -2612,7 +2614,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 200) + self.assertEqual(channel.code, 200) # Check that the callback was called with the right params. mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) @@ -2634,7 +2636,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): }, access_token=self.tok, ) - self.assertEquals(channel.code, 403) + self.assertEqual(channel.code, 403) # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index e2ed14457..c3942889e 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase): sync.register_servlets, ] - def test_user_to_user(self): + def test_user_to_user(self) -> None: """A to-device message from one user to another should get delivered""" user1 = self.register_user("u1", "pass") @@ -73,7 +73,7 @@ class SendToDeviceTestCase(HomeserverTestCase): self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): + def test_local_room_key_request(self) -> None: """m.room_key_request has special-casing; test from local user""" user1 = self.register_user("u1", "pass") user1_tok = self.login("u1", "pass", "d1") @@ -128,7 +128,7 @@ class SendToDeviceTestCase(HomeserverTestCase): ) @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): + def test_remote_room_key_request(self) -> None: """m.room_key_request has special-casing; test from remote user""" user2 = self.register_user("u2", "pass") user2_tok = self.login("u2", "pass", "d2") @@ -199,7 +199,7 @@ class SendToDeviceTestCase(HomeserverTestCase): }, ) - def test_limited_sync(self): + def test_limited_sync(self) -> None: """If a limited sync for to-devices happens the next /sync should respond immediately.""" self.register_user("u1", "pass") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index b0c44af03..ae5ada3be 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -14,6 +14,8 @@ from unittest.mock import Mock, patch +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import ( @@ -23,18 +25,20 @@ from synapse.rest.client import ( room, room_upgrade_rest_servlet, ) +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class _ShadowBannedBase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create two users, one of which is shadow-banned. self.banned_user_id = self.register_user("banned", "test") self.banned_access_token = self.login("banned", "test") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) @@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase): room_upgrade_rest_servlet.register_servlets, ] - def test_invite(self): + def test_invite(self) -> None: """Invites from shadow-banned users don't actually get sent.""" # The create works fine. @@ -77,7 +81,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertEqual(invited_rooms, []) - def test_invite_3pid(self): + def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( @@ -96,12 +100,12 @@ class RoomTestCase(_ShadowBannedBase): {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. identity_handler.lookup_3pid.assert_not_called() - def test_create_room(self): + def test_create_room(self) -> None: """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. channel = self.make_request( @@ -110,7 +114,7 @@ class RoomTestCase(_ShadowBannedBase): {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) room_id = channel.json_body["room_id"] # But the user wasn't actually invited. @@ -126,7 +130,7 @@ class RoomTestCase(_ShadowBannedBase): users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) - def test_message(self): + def test_message(self) -> None: """Messages from shadow-banned users don't actually get sent.""" room_id = self.helper.create_room_as( @@ -151,7 +155,7 @@ class RoomTestCase(_ShadowBannedBase): ) self.assertNotIn(event_id, latest_events) - def test_upgrade(self): + def test_upgrade(self) -> None: """A room upgrade should fail, but look like it succeeded.""" # The create works fine. @@ -165,7 +169,7 @@ class RoomTestCase(_ShadowBannedBase): {"new_version": "6"}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -177,7 +181,7 @@ class RoomTestCase(_ShadowBannedBase): # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) - def test_typing(self): + def test_typing(self) -> None: """Typing notifications should not be propagated into the room.""" # The create works fine. room_id = self.helper.create_room_as( @@ -190,11 +194,11 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # There should be no typing events. event_source = self.hs.get_event_sources().sources.typing - self.assertEquals(event_source.get_current_key(), 0) + self.assertEqual(event_source.get_current_key(), 0) # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) @@ -205,10 +209,10 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # These appear in the room. - self.assertEquals(event_source.get_current_key(), 1) + self.assertEqual(event_source.get_current_key(), 1) events = self.get_success( event_source.get_new_events( user=UserID.from_string(self.other_user_id), @@ -218,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase): room.register_servlets, ] - def test_displayname(self): + def test_displayname(self) -> None: """Profile changes should succeed, but don't end up in a room.""" original_display_name = "banned" new_display_name = "new name" @@ -257,7 +261,7 @@ class ProfileTestCase(_ShadowBannedBase): {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) # The user's display name should be updated. @@ -281,7 +285,7 @@ class ProfileTestCase(_ShadowBannedBase): event.content, {"membership": "join", "displayname": original_display_name} ) - def test_room_displayname(self): + def test_room_displayname(self) -> None: """Changes to state events for a room should be processed, but not end up in the room.""" original_display_name = "banned" new_display_name = "new name" @@ -299,7 +303,7 @@ class ProfileTestCase(_ShadowBannedBase): {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) # The display name in the room should not be changed. diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index 283eccd53..3818b7b14 100644 --- a/tests/rest/client/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, shared_rooms +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): shared_rooms.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user) -> FakeChannel: + def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" @@ -47,14 +51,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): access_token=token, ) - def test_shared_room_list_public(self): + def test_shared_room_list_public(self) -> None: """ A room should show up in the shared list of rooms between two users if it is public. """ self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - def test_shared_room_list_private(self): + def test_shared_room_list_private(self) -> None: """ A room should show up in the shared list of rooms between two users if it is private. @@ -63,7 +67,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): room_one_is_public=False, room_two_is_public=False ) - def test_shared_room_list_mixed(self): + def test_shared_room_list_mixed(self) -> None: """ The shared room list between two users should contain both public and private rooms. @@ -72,7 +76,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): def _check_shared_rooms_with( self, room_one_is_public: bool, room_two_is_public: bool - ): + ) -> None: """Checks that shared public or private rooms between two users appear in their shared room lists """ @@ -91,9 +95,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms from user1's perspective. # We should see the one room in common channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room_id_one) # Create another room and invite user2 to it room_id_two = self.helper.create_room_as( @@ -104,12 +108,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Check shared rooms again. We should now see both rooms. channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 2) for room_id_id in channel.json_body["joined"]: self.assertIn(room_id_id, [room_id_one, room_id_two]) - def test_shared_room_list_after_leave(self): + def test_shared_room_list_after_leave(self) -> None: """ A room should no longer be considered shared if the other user has left it. @@ -125,18 +129,18 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): # Assert user directory is not empty channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 1) + self.assertEqual(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) # Check user1's view of shared rooms with user2 channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) # Check user2's view of shared rooms with user1 channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index cd4af2b1f..435101395 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import List, Optional from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import ( EventContentFields, @@ -24,6 +27,9 @@ from synapse.api.constants import ( RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.federation.transport.test_knocking import ( @@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_argless(self): + def test_sync_argless(self) -> None: channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) @@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_sync_filter_labels(self): + def test_sync_filter_labels(self) -> None: """Test that we can filter by a label.""" sync_filter = json.dumps( { @@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - def test_sync_filter_not_labels(self): + def test_sync_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { @@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events[2]["content"]["body"], "with two wrong labels", events[2] ) - def test_sync_filter_labels_not_labels(self): + def test_sync_filter_labels_not_labels(self) -> None: """Test that we can filter by both a label and the absence of another label.""" sync_filter = json.dumps( { @@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - def _test_sync_filter_labels(self, sync_filter): + def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): user_id = True hijack_auth = False - def test_sync_backwards_typing(self): + def test_sync_backwards_typing(self) -> None: """ If the typing serial goes backwards and the typing handler is then reset (such as when the master restarts and sets the typing serial to 0), we @@ -231,10 +237,10 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. @@ -243,7 +249,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Start typing. channel = self.make_request( @@ -251,11 +257,11 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) # Should return immediately channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Reset typing serial back to 0, as if the master had. @@ -267,7 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # This should time out! But it does not, because our stream token is @@ -275,7 +281,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # already seen) is new, since it's got a token above our new, now-reset # stream token. channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) next_batch = channel.json_body["next_batch"] # Clear the typing information, so that it doesn't think everything is @@ -298,8 +304,8 @@ class SyncKnockTestCase( knock.register_servlets, ] - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.url = "/sync?since=%s" self.next_batch = "s0" @@ -336,7 +342,7 @@ class SyncKnockTestCase( ) @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): + def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room channel = self.make_request( @@ -345,7 +351,7 @@ class SyncKnockTestCase( b"{}", self.knocker_tok, ) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) # We expect to see the knock event in the stripped room state later self.expected_room_state[EventTypes.Member] = { @@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): + def test_hidden_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ] ) def test_read_receipt_with_empty_body( - self, name, user_agent: str, expected_status_code: int - ): + self, name: str, user_agent: str, expected_status_code: int + ) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_status_code) - def _get_read_receipt(self): + def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" # Checks if event is a read receipt - def is_read_receipt(event): + def is_read_receipt(event: JsonDict) -> bool: return event["type"] == "m.receipt" # Sync @@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - def test_unread_counts(self): + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): ) self._check_unread_count(5) - def _check_unread_count(self, expected_count: int): + def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" channel = self.make_request( @@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def test_noop_sync_does_not_tightloop(self): + def test_noop_sync_does_not_tightloop(self) -> None: """If the sync times out, we shouldn't cache the result Essentially a regression test for #8518. @@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): devices.register_servlets, ] - def test_user_with_no_rooms_receives_self_device_list_updates(self): + def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: """Tests that a user with no rooms still receives their own device list updates""" device_id = "TESTDEVICE" diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index ac6b86ff6..bfc04785b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -15,12 +15,12 @@ import threading from typing import TYPE_CHECKING, Dict, Optional, Tuple from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin -from synapse.rest.client import login, room +from synapse.rest.client import account, login, profile, room from synapse.types import JsonDict, Requester, StateMap from synapse.util.frozenutils import unfreeze @@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): admin.register_servlets, login.register_servlets, room.register_servlets, + profile.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -139,7 +141,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) callback.assert_called_once() @@ -157,7 +159,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.result["code"], b"403", channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self): """ @@ -193,7 +195,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEquals(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.result["code"], b"429", channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -329,10 +331,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) - self.assertEquals(event.sender, self.user_id) - self.assertEquals(event.room_id, self.room_id) - self.assertEquals(event.type, "m.room.message") - self.assertEquals(event.content, content) + self.assertEqual(event.sender, self.user_id) + self.assertEqual(event.room_id, self.room_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.content, content) @unittest.override_config( { @@ -530,3 +532,216 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, tok=self.tok, ) + + def test_on_profile_update(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Change the display name. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/displayname" % self.user_id, + {"displayname": displayname}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + m.assert_called_once() + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertIsNone(profile_info.avatar_url) + + # Change the avatar. + channel = self.make_request( + "PUT", + "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id, + {"avatar_url": avatar_url}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called once for our user. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Test that by_admin is False. + self.assertFalse(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_profile_update_admin(self): + """Tests that the on_profile_update module callback is correctly called on + profile updates triggered by a server admin. + """ + displayname = "Foo" + avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" + + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Change a user's profile. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % self.user_id, + {"displayname": displayname, "avatar_url": avatar_url}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the callback has been called twice (since we update the display name + # and avatar separately). + self.assertEqual(m.call_count, 2) + + # Get the arguments for the last call and check it's about the right user. + args = m.call_args[0] + self.assertEqual(args[0], self.user_id) + + # Check that by_admin is True. + self.assertTrue(args[2]) + # Test that deactivation is False. + self.assertFalse(args[3]) + + # Check that we've got the right profile data. + profile_info = args[1] + self.assertEqual(profile_info.display_name, displayname) + self.assertEqual(profile_info.avatar_url, avatar_url) + + def test_on_user_deactivation_status_changed(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append( + deactivation_mock, + ) + # Also register a mocked callback for profile updates, to check that the + # deactivation code calls it in a way that let modules know the user is being + # deactivated. + profile_mock = Mock(return_value=make_awaitable(None)) + self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append( + profile_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID, and with a True + # deactivated flag and a False by_admin flag. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertFalse(args[2]) + + # Check that the profile update callback was called twice (once for the display + # name and once for the avatar URL), and that the "deactivation" boolean is true. + self.assertEqual(profile_mock.call_count, 2) + args = profile_mock.call_args[0] + self.assertTrue(args[3]) + + def test_on_user_deactivation_status_changed_admin(self): + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin as + well as a reactivation. + """ + # Register a mock callback. + m = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + m.assert_called_once() + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with True deactivated + # and by_admin flags. + self.assertEqual(args[0], user_id) + self.assertTrue(args[1]) + self.assertTrue(args[2]) + + # Reactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": False, "password": "hackme"}, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + self.assertEqual(m.call_count, 2) + args = m.call_args[0] + + # Check that the mock was called with the right user ID, and with a False + # deactivated flag and a True by_admin flag. + self.assertEqual(args[0], user_id) + self.assertFalse(args[1]) + self.assertTrue(args[2]) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index ee0abd529..8b2da88e8 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): async def _insert_client_ip(*args, **kwargs): return None - hs.get_datastore().insert_client_ip = _insert_client_ip + hs.get_datastores().main.insert_client_ip = _insert_client_ip return hs @@ -72,9 +72,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( user=UserID.from_string(self.user_id), @@ -84,7 +84,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): is_guest=False, ) ) - self.assertEquals( + self.assertEqual( events[0], [ { @@ -101,7 +101,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) def test_typing_timeout(self): channel = self.make_request( @@ -109,19 +109,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEqual(self.event_source.get_current_key(), 1) self.reactor.advance(36) - self.assertEquals(self.event_source.get_current_key(), 2) + self.assertEqual(self.event_source.get_current_key(), 2) channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.assertEquals(200, channel.code) + self.assertEqual(200, channel.code) - self.assertEquals(self.event_source.get_current_key(), 3) + self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a42388b26..b7d0f42da 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -13,11 +13,14 @@ # limitations under the License. from typing import Optional +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventContentFields, EventTypes, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel @@ -31,8 +34,8 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): room_upgrade_rest_servlet.register_servlets, ] - def prepare(self, reactor, clock, hs: "HomeServer"): - self.store = hs.get_datastore() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") @@ -60,15 +63,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): access_token=token or self.creator_token, ) - def test_upgrade(self): + def test_upgrade(self) -> None: """ Upgrading a room should work fine. """ channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self): + def test_not_in_room(self) -> None: """ Upgrading a room should work fine. """ @@ -77,15 +80,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): roomless_token = self.login(roomless, "pass") channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) - def test_power_levels(self): + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -103,15 +106,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) - def test_power_levels_user_default(self): + def test_power_levels_user_default(self) -> None: """ Another user can upgrade the room if the default power level for users is increased. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -129,15 +132,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) - def test_power_levels_tombstone(self): + def test_power_levels_tombstone(self) -> None: """ Another user can upgrade the room if they can send the tombstone event. """ # The other user doesn't have the proper power level. channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) + self.assertEqual(403, channel.code, channel.result) # Increase the power levels so that this user can upgrade. power_levels = self.helper.get_state( @@ -155,7 +158,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # The upgrade should succeed! channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) power_levels = self.helper.get_state( self.room_id, @@ -164,7 +167,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): ) self.assertNotIn(self.other, power_levels["users"]) - def test_space(self): + def test_space(self) -> None: """Test upgrading a space.""" # Create a space. @@ -197,7 +200,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): # Upgrade the room! channel = self._upgrade_room(room_id=space_id) - self.assertEquals(200, channel.code, channel.result) + self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) new_space_id = channel.json_body["replacement_room"] diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1c0cb0cf4..28663826f 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,6 +19,7 @@ import json import re import time import urllib.parse +from http import HTTPStatus from typing import ( Any, AnyStr, @@ -40,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request @@ -47,15 +49,15 @@ from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser -@attr.s +@attr.s(auto_attribs=True) class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() + hs: HomeServer + site: Site + auth_user_id: Optional[str] @overload def create_room_as( @@ -89,7 +91,7 @@ class RestHelper: is_public: Optional[bool] = None, room_version: Optional[str] = None, tok: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, extra_content: Optional[Dict] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ) -> Optional[str]: @@ -106,9 +108,13 @@ class RestHelper: default room version. tok: The access token to use in the request. expect_code: The expected HTTP response code. + extra_content: Extra keys to include in the body of the /createRoom request. + Note that if is_public is set, the "visibility" key will be overridden. + If room_version is set, the "room_version" key will be overridden. + custom_headers: HTTP headers to include in the request. Returns: - The ID of the newly created room. + The ID of the newly created room, or None if the request failed. """ temp_id = self.auth_user_id self.auth_user_id = room_creator @@ -133,12 +139,19 @@ class RestHelper: assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - if expect_code == 200: + if expect_code == HTTPStatus.OK: return channel.json_body["room_id"] else: return None - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + def invite( + self, + room: str, + src: Optional[str] = None, + targ: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=src, @@ -152,7 +165,7 @@ class RestHelper: self, room: str, user: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, ) -> None: @@ -166,7 +179,14 @@ class RestHelper: expect_code=expect_code, ) - def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): + def knock( + self, + room: Optional[str] = None, + user: Optional[str] = None, + reason: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: temp_id = self.auth_user_id self.auth_user_id = user path = "/knock/%s" % room @@ -195,7 +215,13 @@ class RestHelper: self.auth_user_id = temp_id - def leave(self, room=None, user=None, expect_code=200, tok=None): + def leave( + self, + room: str, + user: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, @@ -205,14 +231,22 @@ class RestHelper: expect_code=expect_code, ) - def ban(self, room: str, src: str, targ: str, **kwargs: object): + def ban( + self, + room: str, + src: str, + targ: str, + expect_code: int = HTTPStatus.OK, + tok: Optional[str] = None, + ) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" self.change_membership( room=room, src=src, targ=targ, + tok=tok, membership=Membership.BAN, - **kwargs, + expect_code=expect_code, ) def change_membership( @@ -224,7 +258,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, ) -> None: """ @@ -290,13 +324,13 @@ class RestHelper: def send( self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, + room_id: str, + body: Optional[str] = None, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if body is None: body = "body_text_here" @@ -314,14 +348,14 @@ class RestHelper: def send_event( self, - room_id, - type, + room_id: str, + type: str, content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, + txn_id: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -353,11 +387,11 @@ class RestHelper: room_id: str, event_type: str, body: Optional[Dict[str, Any]], - tok: str, - expect_code: int = 200, + tok: Optional[str], + expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", - ) -> Dict: + ) -> JsonDict: """Read or write some state from a given room Args: @@ -406,9 +440,9 @@ class RestHelper: room_id: str, event_type: str, tok: str, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Gets some state from a room Args: @@ -433,10 +467,10 @@ class RestHelper: room_id: str, event_type: str, body: Dict[str, Any], - tok: str, - expect_code: int = 200, + tok: Optional[str], + expect_code: int = HTTPStatus.OK, state_key: str = "", - ): + ) -> JsonDict: """Set some state in a room Args: @@ -463,8 +497,8 @@ class RestHelper: image_data: bytes, tok: str, filename: str = "test.png", - expect_code: int = 200, - ) -> dict: + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: """Upload a piece of test media to the media repo Args: resource: The resource that will handle the upload request @@ -509,7 +543,7 @@ class RestHelper: channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) # expect a confirmation page - assert channel.code == 200, channel.result + assert channel.code == HTTPStatus.OK, channel.result # fish the matrix login token out of the body of the confirmation page m = re.search( @@ -528,7 +562,7 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == 200 + assert channel.code == HTTPStatus.OK return channel.json_body def auth_via_oidc( @@ -633,11 +667,16 @@ class RestHelper: (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] - async def mock_req(method: str, uri: str, data=None, headers=None): + async def mock_req( + method: str, + uri: str, + data: Optional[dict] = None, + headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ): (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, + code=HTTPStatus.OK, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), ) @@ -735,7 +774,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint ) # that should serve a confirmation page - assert channel.code == 200, channel.text_body + assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4cf1ed5dd..cba9be17c 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -94,7 +94,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) # Asserts the file is under the expected local cache directory - self.assertEquals( + self.assertEqual( os.path.commonprefix([self.primary_base_path, local_path]), self.primary_base_path, ) @@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 36c495954..02b96c9e6 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -242,7 +242,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 36c933b9e..50c20c5b9 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") def test_background_remove_deleted_devices_from_device_inbox(self): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5..1f6a9eb07 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -37,7 +37,7 @@ from tests import unittest class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main # insert some test data for rid in ("room1", "room2"): @@ -88,18 +88,18 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) + self.assertEqual(res, {"event10"}) # that should result in a single db query - self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( self.store.have_seen_events("room1", ["event10", "event19"]) ) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache @@ -108,8 +108,8 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEquals(res, {"event10"}) - self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + self.assertEqual(res, {"event10"}) + self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) class EventCacheTestCase(unittest.HomeserverTestCase): @@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.user = self.register_user("user", "pass") self.token = self.login(self.user, "pass") @@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage.""" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() + self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" self.event_ids = [f"event{i}" for i in range(20)] diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index d326a1d6a..3ac464696 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -20,7 +20,7 @@ from tests import unittest class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 7496974da..9abd0cb44 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 200b9198f..4899cd5c3 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,7 @@ from tests import unittest class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.storage = hs.get_datastore() + self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) self.get_success( diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d697d2bc1..272cd3540 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ from tests import unittest class IgnoredUsersTestCase(unittest.HomeserverTestCase): def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = "@user:test" def _update_ignore_list( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ddcb7f554..ee599f433 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -88,21 +88,21 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") - self.assertEquals(service, None) + self.assertEqual(service, None) def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) assert stored_service is not None - self.assertEquals(stored_service.token, self.as_token) - self.assertEquals(stored_service.id, self.as_id) - self.assertEquals(stored_service.url, self.as_url) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) - self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) + self.assertEqual(stored_service.token, self.as_token) + self.assertEqual(stored_service.id, self.as_id) + self.assertEqual(stored_service.url, self.as_url) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEqual(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() - self.assertEquals(len(services), 3) + self.assertEqual(len(services), 3) class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): @@ -182,7 +182,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) -> None: service = Mock(id="999") state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(None, state) + self.assertEqual(None, state) def test_get_appservice_state_up( self, @@ -194,7 +194,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): state = self.get_success( defer.ensureDeferred(self.store.get_appservice_state(service)) ) - self.assertEquals(ApplicationServiceState.UP, state) + self.assertEqual(ApplicationServiceState.UP, state) def test_get_appservice_state_down( self, @@ -210,7 +210,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) service = Mock(id=self.as_list[1]["id"]) state = self.get_success(self.store.get_appservice_state(service)) - self.assertEquals(ApplicationServiceState.DOWN, state) + self.assertEqual(ApplicationServiceState.DOWN, state) def test_get_appservices_by_state_none( self, @@ -218,7 +218,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(0, len(services)) + self.assertEqual(0, len(services)) def test_set_appservices_state_down( self, @@ -235,7 +235,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.DOWN.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_set_appservices_state_multiple_up( self, @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (ApplicationServiceState.UP.value,), ) ) - self.assertEquals(service.id, rows[0][0]) + self.assertEqual(service.id, rows[0][0]) def test_create_appservice_txn_first( self, @@ -267,12 +267,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) - self.assertEquals(txn.id, 1) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 1) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_older_last_txn( self, @@ -283,11 +283,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9646) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9646) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_to_date_last_txn( self, @@ -296,11 +296,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_create_appservice_txn_up_fuzzing( self, @@ -320,11 +320,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) - self.assertEquals(txn.id, 9644) - self.assertEquals(txn.events, events) - self.assertEquals(txn.service, service) + self.assertEqual(txn.id, 9644) + self.assertEqual(txn.events, events) + self.assertEqual(txn.service, service) def test_complete_appservice_txn_first_txn( self, @@ -346,8 +346,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) res = self.get_success( self.db_pool.runQuery( @@ -357,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_complete_appservice_txn_existing_in_state_table( self, @@ -379,9 +379,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (service.id,), ) ) - self.assertEquals(1, len(res)) - self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) + self.assertEqual(1, len(res)) + self.assertEqual(txn_id, res[0][0]) + self.assertEqual(ApplicationServiceState.UP.value, res[0][1]) res = self.get_success( self.db_pool.runQuery( @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (txn_id,), ) ) - self.assertEquals(0, len(res)) + self.assertEqual(0, len(res)) def test_get_oldest_unsent_txn_none( self, @@ -399,7 +399,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(None, txn) + self.assertEqual(None, txn) def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) @@ -416,9 +416,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) - self.assertEquals(service, txn.service) - self.assertEquals(10, txn.id) - self.assertEquals(events, txn.events) + self.assertEqual(service, txn.service) + self.assertEqual(10, txn.id) + self.assertEqual(events, txn.events) def test_get_appservices_by_state_single( self, @@ -433,8 +433,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(1, len(services)) - self.assertEquals(self.as_list[0]["id"], services[0].id) + self.assertEqual(1, len(services)) + self.assertEqual(self.as_list[0]["id"], services[0].id) def test_get_appservices_by_state_multiple( self, @@ -455,8 +455,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) - self.assertEquals(2, len(services)) - self.assertEquals( + self.assertEqual(2, len(services)) + self.assertEqual( {self.as_list[2]["id"], self.as_list[0]["id"]}, {services[0].id, services[1].id}, ) @@ -467,7 +467,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.service = Mock(id="foo") - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.get_success( self.store.set_appservice_state(self.service, ApplicationServiceState.UP) ) @@ -476,12 +476,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEquals(value, 0) + self.assertEqual(value, 0) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6156dfac4..39dcc094b 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -24,7 +24,7 @@ from tests.test_utils import make_awaitable, simple_async_mock class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -42,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 100 - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", @@ -102,7 +102,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -138,7 +138,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): ) def test_controller(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main self.get_success( store.db_pool.simple_insert( "background_updates", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c..a8ffb52c0 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -103,7 +103,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals("Value", value) + self.assertEqual("Value", value) self.mock_txn.execute.assert_called_with( "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -121,7 +121,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -154,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index a59c28f89..ce89c9691 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -242,7 +242,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() self.event_creator_handler = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c8ac67e35..49ad3c132 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -35,7 +35,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main def test_insert_new_client_ip(self): self.reactor.advance(12345678) @@ -666,7 +666,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): return hs def prepare(self, hs, reactor, clock): - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user_id = self.register_user("bob", "abc123", True) def test_request_with_xforwarded(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index b547bf8d9..21ffc5a90 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_store_new_device(self): self.get_success( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 43628ce44..20bf3ca17 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase class DirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") @@ -31,7 +31,7 @@ class DirectoryStoreTestCase(HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( ["#my-room:test"], (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 7556171d8..fb96ab3a2 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -28,7 +28,7 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver("server", federation_http_client=None) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main return hs def test_room_keys_version_delete(self): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 3bf6e337f..0f04493ad 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -17,7 +17,7 @@ from tests.unittest import HomeserverTestCase class EndToEndKeyStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_key_without_device_name(self): now = 1470174257070 diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e3273a93f..401020fd6 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -30,7 +30,7 @@ from tests.unittest import HomeserverTestCase class EventChainStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self._next_stream_ordering = 1 def test_simple(self): @@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = self.register_user("foo", "pass") self.token = self.login("foo", "pass") self.requester = create_requester(self.user_id) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 667ca90a4..645d564d1 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -31,7 +31,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): room_id = "@ROOM:local" diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 738f3ad1d..0f9add484 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -30,7 +30,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.persist_events_store = hs.get_datastores().persist_events def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -57,7 +57,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) ) - self.assertEquals( + self.assertEqual( counts, NotifCounts( notify_count=noitf_count, diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f462a8b1c..ef5e25873 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self.persistence = self.hs.get_storage().persistence - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.register_user("user", "pass") self.token = self.login("user", "pass") @@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) + + +class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastores().main + + def test_remote_user_rooms_cache_invalidated(self): + """Test that if the server leaves a room the `get_rooms_for_user` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_rooms_for_user` to add the remote user to the cache + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), {room_id}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), set()) + + def test_room_remote_user_cache_invalidated(self): + """Test that if the server leaves a room the `get_users_in_room` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_users_in_room` to add the remote user to the cache + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(set(users), {user_id, remote_user}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(users, []) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 748607828..6ac4b93f9 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): skip = "Requires Postgres" def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index a94b5fd72..905909552 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -37,7 +37,7 @@ KEY_2 = decode_verify_key_base64( class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" @@ -74,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_cache(self): """Check that updates correctly invalidate the cache.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4..5806cb0e4 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: super(DataStoreTestCase, self).setUp() - self.store = self.hs.get_datastore() + self.store = self.hs.get_datastores().main self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" @@ -38,12 +38,12 @@ class DataStoreTestCase(unittest.HomeserverTestCase): self.store.get_users_paginate(0, 10, name="bc", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) - self.assertEquals(1, total) - self.assertEquals(self.displayname, users.pop()["displayname"]) + self.assertEqual(1, total) + self.assertEqual(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6b4cdd78..79648d45d 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -45,7 +45,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main # Advance the clock a bit reactor.advance(FORTY_DAYS) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index d37736edf..a019d06e0 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -22,7 +22,7 @@ from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_frank = UserID.from_string("@frank:test") @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) - self.assertEquals( + self.assertEqual( "Frank", ( self.get_success( @@ -60,7 +60,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals( + self.assertEqual( "http://my.site/here", ( self.get_success( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 22a77c3cc..08cc60237 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -30,7 +30,7 @@ class PurgeTests(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = self.hs.get_storage() def test_purge_history(self): @@ -47,7 +47,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - token_str = self.get_success(token.to_string(self.hs.get_datastore())) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8c95a0a2f..03e9cc7d4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -30,7 +30,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 974806528..a49ac1525 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] @@ -30,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): def test_register(self): self.get_success(self.store.register_user(self.user_id, self.pwhash)) - self.assertEquals( + self.assertEqual( { # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, @@ -131,7 +131,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Unknown session_id", e) + self.assertEqual(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True @@ -146,4 +146,4 @@ class RegistrationStoreTestCase(HomeserverTestCase): ), ThreepidValidationError, ) - self.assertEquals(e.value.msg, "Validation token not found or has expired", e) + self.assertEqual(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index cfc8098af..0baa54312 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -56,7 +56,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_rolling_back(self): """Test that workers can start if the DB is a newer schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -72,7 +72,7 @@ class WorkerSchemaTests(HomeserverTestCase): def test_not_upgraded_old_schema_version(self): """Test that workers don't start if the DB has an older schema version""" - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, @@ -92,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase): Test that workers don't start if the DB is on the current schema version, but there are still outstanding delta migrations to run. """ - db_pool = self.hs.get_datastore().db_pool + db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( db_pool._db_pool.connect(), db_pool.engine, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 31ce7f625..5b011e18c 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#a-room-name:test") @@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() @@ -104,7 +104,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, state[0], @@ -121,7 +121,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): self.store.get_current_state(room_id=self.room.to_string()) ) - self.assertEquals(1, len(state)) + self.assertEqual(1, len(state)) self.assertObjectHasAttributes( {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, state[0], diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8971ecccb..8dfc1e1db 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -13,13 +13,16 @@ # limitations under the License. import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError from synapse.rest.client import login, room from synapse.storage.engines import PostgresEngine -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, skip_unless +from tests.utils import USE_POSTGRES_FOR_TESTS -class NullByteInsertionTest(HomeserverTestCase): +class EventSearchInsertionTest(HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -46,11 +49,11 @@ class NullByteInsertionTest(HomeserverTestCase): self.assertIn("event_id", response) # Check that search works for the message where the null byte was replaced - store = self.hs.get_datastore() + store = self.hs.get_datastores().main result = self.get_success( store.search_msgs([room_id], "hi bob", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("hi", result.get("highlights")) self.assertIn("bob", result.get("highlights")) @@ -59,16 +62,126 @@ class NullByteInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) - self.assertEquals(result.get("count"), 1) + self.assertEqual(result.get("count"), 1) if isinstance(store.database_engine, PostgresEngine): self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"])) - self.assertEquals(result.get("count"), 2) + self.assertEqual(result.get("count"), 2) result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) if isinstance(store.database_engine, PostgresEngine): self.assertIn("alice", result.get("highlights")) + + def test_non_string(self): + """Test that non-string `value`s are not inserted into `event_search`. + + This is particularly important when using sqlite, since a sqlite column can hold + both strings and integers. When using Postgres, integers are automatically + converted to strings. + + Regression test for #11918. + """ + store = self.hs.get_datastores().main + + # Register a user and create a room + user_id = self.register_user("alice", "password") + access_token = self.login("alice", "password") + room_id = self.helper.create_room_as("alice", tok=access_token) + room_version = self.get_success(store.get_room_version(room_id)) + + # Construct a message with a numeric body to be received over federation + # The message can't be sent using the client API, since Synapse's event + # validation will reject it. + prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + prev_event = self.get_success(store.get_event(prev_event_ids[0])) + prev_state_map = self.get_success( + self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + ) + + event_dict = { + "type": EventTypes.Message, + "content": {"msgtype": "m.text", "body": 2}, + "room_id": room_id, + "sender": user_id, + "depth": prev_event.depth + 1, + "prev_events": prev_event_ids, + "origin_server_ts": self.clock.time_msec(), + } + builder = self.hs.get_event_builder_factory().for_room_version( + room_version, event_dict + ) + event = self.get_success( + builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=self.hs.get_event_auth_handler().compute_auth_events( + builder, + prev_state_map, + for_verification=False, + ), + depth=event_dict["depth"], + ) + ) + + # Receive the event + self.get_success( + self.hs.get_federation_event_handler().on_receive_pdu( + self.hs.hostname, event + ) + ) + + # The event should not have an entry in the `event_search` table + f = self.get_failure( + store.db_pool.simple_select_one_onecol( + "event_search", + {"room_id": room_id, "event_id": event.event_id}, + "event_id", + ), + StoreError, + ) + self.assertEqual(f.value.code, 404) + + @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite") + def test_sqlite_non_string_deletion_background_update(self): + """Test the background update to delete bad rows from `event_search`.""" + store = self.hs.get_datastores().main + + # Populate `event_search` with dummy data + self.get_success( + store.db_pool.simple_insert_many( + "event_search", + keys=["event_id", "room_id", "key", "value"], + values=[ + ("event1", "room_id", "content.body", "hi"), + ("event2", "room_id", "content.body", "2"), + ("event3", "room_id", "content.body", 3), + ], + desc="populate_event_search", + ) + ) + + # Run the background update + store.db_pool.updates._all_done = False + self.get_success( + store.db_pool.simple_insert( + "background_updates", + { + "update_name": "event_search_sqlite_delete_non_strings", + "progress_json": "{}", + }, + ) + ) + self.wait_for_background_updates() + + # The non-string `value`s ought to be gone now. + values = self.get_success( + store.db_pool.simple_select_onecol( + "event_search", + {"room_id": "room_id"}, + "value", + ), + ) + self.assertCountEqual(values, ["hi", "2"]) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b8..b8f09a8ee 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -35,7 +35,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,7 +55,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) ) - self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + self.assertEqual([self.room], [m.room_id for m in rooms_for_user]) def test_count_known_servers(self): """ @@ -212,7 +212,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() def test_can_rerun_update(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088..f88f1c55f 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() @@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index ce782c7e1..6a1cf3305 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -115,7 +115,7 @@ class PaginationTestCase(HomeserverTestCase): ) events, next_key = self.get_success( - self.hs.get_datastore().paginate_room_events( + self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, from_key=from_token.room_key, to_key=None, diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index bea9091d3..e05daa285 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase class TransactionStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_get_set_transactions(self): """Tests that we can successfully get a non-existent entry for diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 48f1e9d84..7f1964eb6 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -149,7 +149,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main self.user_dir_helper = GetUserDirectoryTables(self.store) def _purge_and_rebuild_user_dir(self) -> None: @@ -415,7 +415,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastore() + self.store = hs.get_datastores().main # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_distributor.py b/tests/test_distributor.py index f8341041e..31546ea52 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -48,7 +48,7 @@ class DistributorTestCase(unittest.TestCase): observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") - self.assertEquals(mock_logger.warning.call_count, 1) + self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) def test_signal_prereg(self): diff --git a/tests/test_federation.py b/tests/test_federation.py index 2b9804aba..c39816de8 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -52,11 +52,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) )[0]["room_id"] - self.store = self.homeserver.get_datastore() + self.store = self.homeserver.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) + self.homeserver.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) )[0] join_event = make_event_from_dict( @@ -185,7 +187,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastore() + store = self.homeserver.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e25..46bd3075d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -52,7 +52,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastore() + self.store = homeserver.get_datastores().main def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated diff --git a/tests/test_state.py b/tests/test_state.py index 76e0e8ca7..e4baa6913 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Collection, Dict, List, Optional from unittest.mock import Mock from twisted.internet import defer @@ -70,7 +70,7 @@ def create_event( return event -class StateGroupStore: +class _DummyStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -105,6 +105,11 @@ class StateGroupStore: if e_id in self._event_id_to_event } + async def get_partial_state_events( + self, event_ids: Collection[str] + ) -> Dict[str, bool]: + return {e: False for e in event_ids} + async def get_state_group_delta(self, name): return None, None @@ -157,12 +162,12 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): - self.store = StateGroupStore() - storage = Mock(main=self.store, state=self.store) + self.dummy_store = _DummyStore() + storage = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", - "get_datastore", + "get_datastores", "get_storage", "get_auth", "get_state_handler", @@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase): ] ) hs.config = default_config("tesths", True) - hs.get_datastore.return_value = self.store + hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) @@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store: dict[str, EventContext] = {} @@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] @@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between B and C @@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase): edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # C ends up winning the resolution between C and D because bans win over other @@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - self.store.register_events(graph.walk()) + self.dummy_store.register_events(graph.walk()) context_store = {} @@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred( self.state.compute_event_context(event) ) - self.store.register_event_context(event, context) + self.dummy_store.register_event_context(event, context) context_store[event.event_id] = context # B ends up winning the resolution between B and C because power levels @@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase): ] group_name = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id, event.room_id, None, @@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state}, ) ) - self.store.register_event_id_state_group(prev_event_id, group_name) + self.dummy_store.register_event_id_state_group(prev_event_id, group_name) context = yield defer.ensureDeferred(self.state.compute_event_context(event)) @@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - self.store.register_events(old_state_1) - self.store.register_events(old_state_2) + self.dummy_store.register_events(old_state_1) + self.dummy_store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] - store = StateGroupStore() + store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.store.get_events = store.get_events + self.dummy_store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 @@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase): self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_1, event.room_id, None, @@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_1}, ) ) - self.store.register_event_id_state_group(prev_event_id_1, sg1) + self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = yield defer.ensureDeferred( - self.store.store_state_group( + self.dummy_store.store_state_group( prev_event_id_2, event.room_id, None, @@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase): {(e.type, e.state_key): e.event_id for e in old_state_2}, ) ) - self.store.register_event_id_state_group(prev_event_id_2, sg2) + self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2) result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 67dcf567c..37fada5c5 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -54,7 +54,7 @@ class TermsTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["session"], str) @@ -99,7 +99,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. - self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms request_data = json.dumps( @@ -117,7 +117,7 @@ class TermsTestCase(unittest.HomeserverTestCase): # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. - self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) self.assertIsInstance(channel.json_body["user_id"], str) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index f2ef1c605..d04bcae0f 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -25,7 +25,7 @@ class MockClockTestCase(unittest.TestCase): self.clock.advance_time(20) - self.assertEquals(20, self.clock.time() - start_time) + self.assertEqual(20, self.clock.time() - start_time) def test_later(self): invoked = [0, 0] diff --git a/tests/test_types.py b/tests/test_types.py index 0d0c00d97..80888a744 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -22,9 +22,9 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): user = UserID.from_string("@1234abcd:test") - self.assertEquals("1234abcd", user.localpart) - self.assertEquals("test", user.domain) - self.assertEquals(True, self.hs.is_mine(user)) + self.assertEqual("1234abcd", user.localpart) + self.assertEqual("test", user.domain) + self.assertEqual(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -33,7 +33,7 @@ class UserIDTestCase(unittest.HomeserverTestCase): def test_build(self): user = UserID("5678efgh", "my.domain") - self.assertEquals(user.to_string(), "@5678efgh:my.domain") + self.assertEqual(user.to_string(), "@5678efgh:my.domain") def test_compare(self): userA = UserID.from_string("@userA:my.domain") @@ -48,14 +48,14 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): room = RoomAlias.from_string("#channel:test") - self.assertEquals("channel", room.localpart) - self.assertEquals("test", room.domain) - self.assertEquals(True, self.hs.is_mine(room)) + self.assertEqual("channel", room.localpart) + self.assertEqual("test", room.domain) + self.assertEqual(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") - self.assertEquals(room.to_string(), "#channel:my.domain") + self.assertEqual(room.to_string(), "#channel:my.domain") def test_validate(self): id_string = "#test:domain,test" diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e9ec9e085..c654e36ee 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -85,7 +85,9 @@ async def create_event( **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: - room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) + room_version = await hs.get_datastores().main.get_room_version_id( + kwargs["room_id"] + ) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d..219b5660b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -93,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + self.get_success( + self.hs.get_datastores().main.mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index a71892cb9..326895f4c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -51,7 +51,10 @@ from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite @@ -149,12 +152,12 @@ class TestCase(unittest.TestCase): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and - that the value of each matches according to assertEquals.""" + that the value of each matches according to assertEqual.""" for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: - self.assertEquals(attrs[key], getattr(obj, key)) + self.assertEqual(attrs[key], getattr(obj, key)) except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e @@ -166,7 +169,7 @@ class TestCase(unittest.TestCase): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals( + self.assertEqual( required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) @@ -277,7 +280,7 @@ class HomeserverTestCase(TestCase): # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( - self.hs.get_datastore().add_access_token_to_user( + self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, "some_fake_token", None, @@ -334,7 +337,7 @@ class HomeserverTestCase(TestCase): def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): @@ -501,7 +504,7 @@ class HomeserverTestCase(TestCase): self.get_success(stor.db_pool.updates.run_background_updates(False)) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) - stor = hs.get_datastore() + stor = hs.get_datastores().main # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": @@ -719,14 +722,16 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db_pool.simple_insert( + self.hs.get_datastores().main.db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", ) ) - self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting @@ -772,7 +777,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastore().store_server_verify_keys( + hs.get_datastores().main.store_server_verify_keys( from_server=self.OTHER_SERVER_NAME, ts_added_ms=clock.time_msec(), verify_keys=[ @@ -839,6 +844,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) + def add_hashes_and_signatures( + self, + event_dict: JsonDict, + room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) -> JsonDict: + """Adds hashes and signatures to the given event dict + + Returns: + The modified event dict, for convenience + """ + add_hashes_and_signatures( + room_version, + event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + return event_dict + def _auth_header_for_request( origin: str, diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index c613ce3f1..02b99b466 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -31,7 +31,7 @@ class DeferredCacheTestCase(TestCase): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(self.successResultOf(cache.get("foo")), 123) + self.assertEqual(self.successResultOf(cache.get("foo")), 123) def test_hit_deferred(self): cache = DeferredCache("test") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index ced3efd93..19741ffcd 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -434,8 +434,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual((yield a.func("bar")), "bar") @defer.inlineCallbacks def test_hit(self): @@ -450,10 +450,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) + self.assertEqual((yield a.func("foo")), "foo") + self.assertEqual(callcount[0], 1) @defer.inlineCallbacks def test_invalidate(self): @@ -468,13 +468,13 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func("foo") - self.assertEquals(callcount[0], 1) + self.assertEqual(callcount[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) + self.assertEqual(callcount[0], 2) def test_invalidate_missing(self): class A: @@ -499,7 +499,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): for k in range(0, 12): yield a.func(k) - self.assertEquals(callcount[0], 12) + self.assertEqual(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times @@ -525,8 +525,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a.func.prefill(("foo",), 456) - self.assertEquals(a.func("foo").result, 456) - self.assertEquals(callcount[0], 0) + self.assertEqual(a.func("foo").result, 456) + self.assertEqual(callcount[0], 0) @defer.inlineCallbacks def test_invalidate_context(self): @@ -547,19 +547,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func.invalidate(("foo",)) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 1) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) @defer.inlineCallbacks def test_eviction_context(self): @@ -581,22 +581,22 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") yield a.func2("foo2") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func("foo3") - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 3) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 4) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 4) + self.assertEqual(callcount2[0], 3) @defer.inlineCallbacks def test_double_get(self): @@ -619,30 +619,30 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): yield a.func2("foo") - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2) - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 1) + self.assertEqual(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) + self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 2) yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) + self.assertEqual(callcount[0], 2) + self.assertEqual(callcount2[0], 3) class CachedListDescriptorTestCase(unittest.TestCase): @@ -673,14 +673,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with((30,), 2) + obj.mock.assert_called_once_with({30}, 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -701,7 +701,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.return_value = {40: "gravy"} iterable = (x for x in [10, 40, 40]) r = yield obj.list_fn(iterable, 2) - obj.mock.assert_called_once_with((40,), 2) + obj.mock.assert_called_once_with({40}, 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) def test_concurrent_lookups(self): @@ -729,7 +729,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): d3 = obj.list_fn([10]) # the mock should have been called exactly once - obj.mock.assert_called_once_with((10,)) + obj.mock.assert_called_once_with({10}) obj.mock.reset_mock() # ... and none of the calls should yet be complete @@ -771,7 +771,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with((10, 20), 2) + obj.mock.assert_called_once_with({10, 20}, 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ab89cab81..362014f4c 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock +from twisted.python.failure import Failure from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -21,7 +24,12 @@ from synapse.logging.context import ( PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.async_helpers import ( + ObservableDeferred, + concurrently_execute, + stop_cancellation, + timeout_deferred, +) from tests.unittest import TestCase @@ -171,3 +179,151 @@ class TimeoutDeferredTest(TestCase): ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) + + +class _TestException(Exception): + pass + + +class ConcurrentlyExecuteTest(TestCase): + def test_limits_runners(self): + """If we have more tasks than runners, we should get the limit of runners""" + started = 0 + waiters = [] + processed = [] + + async def callback(v): + # when we first enter, bump the start count + nonlocal started + started += 1 + + # record the fact we got an item + processed.append(v) + + # wait for the goahead before returning + d2 = Deferred() + waiters.append(d2) + await d2 + + # set it going + d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3)) + + # check we got exactly 3 processes + self.assertEqual(started, 3) + self.assertEqual(len(waiters), 3) + + # let one finish + waiters.pop().callback(0) + + # ... which should start another + self.assertEqual(started, 4) + self.assertEqual(len(waiters), 3) + + # we still shouldn't be done + self.assertNoResult(d2) + + # finish the job + while waiters: + waiters.pop().callback(0) + + # check everything got done + self.assertEqual(started, 5) + self.assertCountEqual(processed, [1, 2, 3, 4, 5]) + self.successResultOf(d2) + + def test_preserves_stacktraces(self): + """Test that the stacktrace from an exception thrown in the callback is preserved""" + d1 = Deferred() + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + raise _TestException("bah") + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute" and "callback". + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-1].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + def test_preserves_stacktraces_on_preformed_failure(self): + """Test that the stacktrace on a Failure returned by the callback is preserved""" + d1 = Deferred() + f = Failure(_TestException("bah")) + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + await defer.fail(f) + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute", "callback", + # and some magic from inside ensureDeferred that happens when .fail + # is called. + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-2].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + + def test_succeed(self): + """Test that the new `Deferred` receives the result.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Success should propagate through. + deferred.callback("success") + self.assertTrue(wrapper_deferred.called) + self.assertEqual("success", self.successResultOf(wrapper_deferred)) + + def test_failure(self): + """Test that the new `Deferred` receives the `Failure`.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Failure should propagate through. + deferred.errback(ValueError("abc")) + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, ValueError) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` leaves the original running.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, CancelledError) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled." + ) + + # Now make the inner `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py new file mode 100644 index 000000000..3c0725225 --- /dev/null +++ b/tests/util/test_check_dependencies.py @@ -0,0 +1,95 @@ +from contextlib import contextmanager +from typing import Generator, Optional +from unittest.mock import patch + +from synapse.util.check_dependencies import ( + DependencyException, + check_requirements, + metadata, +) + +from tests.unittest import TestCase + + +class DummyDistribution(metadata.Distribution): + def __init__(self, version: str): + self._version = version + + @property + def version(self): + return self._version + + def locate_file(self, path): + raise NotImplementedError() + + def read_text(self, filename): + raise NotImplementedError() + + +old = DummyDistribution("0.1.2") +new = DummyDistribution("1.2.3") + +# could probably use stdlib TestCase --- no need for twisted here + + +class TestDependencyChecker(TestCase): + @contextmanager + def mock_installed_package( + self, distribution: Optional[DummyDistribution] + ) -> Generator[None, None, None]: + """Pretend that looking up any distribution yields the given `distribution`.""" + + def mock_distribution(name: str): + if distribution is None: + raise metadata.PackageNotFoundError + else: + return distribution + + with patch( + "synapse.util.check_dependencies.metadata.distribution", + mock_distribution, + ): + yield + + def test_mandatory_dependency(self) -> None: + """Complain if a required package is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1"], + ): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_generic_check_of_optional_dependency(self) -> None: + """Complain if an optional package is old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ): + with self.mock_installed_package(None): + # should not raise + check_requirements() + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements) + with self.mock_installed_package(new): + # should not raise + check_requirements() + + def test_check_for_extra_dependencies(self) -> None: + """Complain if a package required for an extra is missing or old.""" + with patch( + "synapse.util.check_dependencies.metadata.requires", + return_value=["dummypkg >= 1; extra == 'cool-extra'"], + ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}): + with self.mock_installed_package(None): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(old): + self.assertRaises(DependencyException, check_requirements, "cool-extra") + with self.mock_installed_package(new): + # should not raise + check_requirements() diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index e6e13ba06..7f60aae5b 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -26,8 +26,8 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache = ExpiringCache("test", clock, max_len=1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): clock = MockClock() @@ -35,13 +35,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key"] = "value" cache["key2"] = "value2" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache.get("key2"), "value2") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache.get("key2"), "value2") cache["key3"] = "value3" - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), "value2") - self.assertEquals(cache.get("key3"), "value3") + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), "value2") + self.assertEqual(cache.get("key3"), "value3") def test_iterable_eviction(self): clock = MockClock() @@ -51,15 +51,15 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): cache["key2"] = [2, 3] cache["key3"] = [4, 5] - self.assertEquals(cache.get("key"), [1]) - self.assertEquals(cache.get("key2"), [2, 3]) - self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key"), [1]) + self.assertEqual(cache.get("key2"), [2, 3]) + self.assertEqual(cache.get("key3"), [4, 5]) cache["key4"] = [6, 7] - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache.get("key3"), [4, 5]) - self.assertEquals(cache.get("key4"), [6, 7]) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache.get("key3"), [4, 5]) + self.assertEqual(cache.get("key4"), [6, 7]) def test_time_eviction(self): clock = MockClock() @@ -69,13 +69,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): clock.advance_time(0.5) cache["key2"] = 2 - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(0.9) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), 2) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), 2) clock.advance_time(1) - self.assertEquals(cache.get("key"), None) - self.assertEquals(cache.get("key2"), None) + self.assertEqual(cache.get("key"), None) + self.assertEqual(cache.get("key2"), None) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 621b0f9fc..2ad321e18 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -17,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().name, value) + self.assertEqual(current_context().name, value) def test_with_context(self): with LoggingContext("test"): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7..321fc1776 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -27,37 +27,37 @@ class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + self.assertEqual(cache.get("key"), "value") + self.assertEqual(cache["key"], "value") def test_eviction(self): cache = LruCache(2) cache[1] = 1 cache[2] = 2 - self.assertEquals(cache.get(1), 1) - self.assertEquals(cache.get(2), 2) + self.assertEqual(cache.get(1), 1) + self.assertEqual(cache.get(2), 2) cache[3] = 3 - self.assertEquals(cache.get(1), None) - self.assertEquals(cache.get(2), 2) - self.assertEquals(cache.get(3), 3) + self.assertEqual(cache.get(1), None) + self.assertEqual(cache.get(2), 2) + self.assertEqual(cache.get(3), 3) def test_setdefault(self): cache = LruCache(1) - self.assertEquals(cache.setdefault("key", 1), 1) - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.setdefault("key", 2), 1) - self.assertEquals(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 1), 1) + self.assertEqual(cache.get("key"), 1) + self.assertEqual(cache.setdefault("key", 2), 1) + self.assertEqual(cache.get("key"), 1) cache["key"] = 2 # Make sure overriding works. - self.assertEquals(cache.get("key"), 2) + self.assertEqual(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) cache["key"] = 1 - self.assertEquals(cache.pop("key"), 1) - self.assertEquals(cache.pop("key"), None) + self.assertEqual(cache.pop("key"), 1) + self.assertEqual(cache.pop("key"), None) def test_del_multi(self): cache = LruCache(4, cache_type=TreeCache) @@ -66,23 +66,23 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache[("vehicles", "car")] = "vroom" cache[("vehicles", "train")] = "chuff" - self.assertEquals(len(cache), 4) + self.assertEqual(len(cache), 4) - self.assertEquals(cache.get(("animal", "cat")), "mew") - self.assertEquals(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("animal", "cat")), "mew") + self.assertEqual(cache.get(("vehicles", "car")), "vroom") cache.del_multi(("animal",)) - self.assertEquals(len(cache), 2) - self.assertEquals(cache.get(("animal", "cat")), None) - self.assertEquals(cache.get(("animal", "dog")), None) - self.assertEquals(cache.get(("vehicles", "car")), "vroom") - self.assertEquals(cache.get(("vehicles", "train")), "chuff") + self.assertEqual(len(cache), 2) + self.assertEqual(cache.get(("animal", "cat")), None) + self.assertEqual(cache.get(("animal", "dog")), None) + self.assertEqual(cache.get(("vehicles", "car")), "vroom") + self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". def test_clear(self): cache = LruCache(1) cache["key"] = 1 cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self): @@ -105,10 +105,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_multi_get(self): m = Mock() @@ -124,10 +124,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_set(self): m = Mock() @@ -140,10 +140,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.set("key", "value2") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_pop(self): m = Mock() @@ -153,13 +153,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertFalse(m.called) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.set("key", "value") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) cache.pop("key") - self.assertEquals(m.call_count, 1) + self.assertEqual(m.call_count, 1) def test_del_multi(self): m1 = Mock() @@ -173,17 +173,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set(("b", "1"), "value", callbacks=[m3]) cache.set(("b", "2"), "value", callbacks=[m4]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) cache.del_multi(("a",)) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) - self.assertEquals(m3.call_count, 0) - self.assertEquals(m4.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) + self.assertEqual(m3.call_count, 0) + self.assertEqual(m4.call_count, 0) def test_clear(self): m1 = Mock() @@ -193,13 +193,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) cache.clear() - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 1) def test_eviction(self): m1 = Mock(name="m1") @@ -210,33 +210,33 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) - self.assertEquals(m1.call_count, 0) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 0) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value", callbacks=[m3]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key3", "value") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.get("key2") - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 0) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 0) cache.set("key1", "value", callbacks=[m1]) - self.assertEquals(m1.call_count, 1) - self.assertEquals(m2.call_count, 0) - self.assertEquals(m3.call_count, 1) + self.assertEqual(m1.call_count, 1) + self.assertEqual(m2.call_count, 0) + self.assertEqual(m3.call_count, 1) class LruCacheSizedTestCase(unittest.HomeserverTestCase): @@ -247,20 +247,20 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): cache["key3"] = [3] cache["key4"] = [4] - self.assertEquals(cache["key1"], [0]) - self.assertEquals(cache["key2"], [1, 2]) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(len(cache), 5) + self.assertEqual(cache["key1"], [0]) + self.assertEqual(cache["key2"], [1, 2]) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(len(cache), 5) cache["key5"] = [5, 6] - self.assertEquals(len(cache), 4) - self.assertEquals(cache.get("key1"), None) - self.assertEquals(cache.get("key2"), None) - self.assertEquals(cache["key3"], [3]) - self.assertEquals(cache["key4"], [4]) - self.assertEquals(cache["key5"], [5, 6]) + self.assertEqual(len(cache), 4) + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + self.assertEqual(cache["key3"], [3]) + self.assertEqual(cache["key4"], [4]) + self.assertEqual(cache["key5"], [5, 6]) def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e1bebdc8..26cb71c64 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -24,7 +24,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request @@ -38,7 +38,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_limiter(self): """General test case which walks through the process of a failing request""" - store = self.hs.get_datastore() + store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index a10071c70..0774625b8 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # limitations under the License. from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.util.async_helpers import ReadWriteLock @@ -83,3 +84,32 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d.called) with d.result: pass + + def test_lock_handoff_to_nonblocking_writer(self): + """Test a writer handing the lock to another writer that completes instantly.""" + rwlock = ReadWriteLock() + key = "key" + + unblock: "Deferred[None]" = Deferred() + + async def blocking_write(): + with await rwlock.write(key): + await unblock + + async def nonblocking_write(): + with await rwlock.write(key): + pass + + d1 = defer.ensureDeferred(blocking_write()) + d2 = defer.ensureDeferred(nonblocking_write()) + self.assertFalse(d1.called) + self.assertFalse(d2.called) + + # Unblock the first writer. The second writer will complete without blocking. + unblock.callback(None) + self.assertTrue(d1.called) + self.assertTrue(d2.called) + + # The `ReadWriteLock` should operate as normal. + d3 = defer.ensureDeferred(nonblocking_write()) + self.assertTrue(d3.called) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 606637205..567cb1846 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -23,61 +23,61 @@ class TreeCacheTestCase(unittest.TestCase): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.get(("a",)), "A") - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 2) + self.assertEqual(cache.get(("a",)), "A") + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 2) def test_pop_onelevel(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" - self.assertEquals(cache.pop(("a",)), "A") - self.assertEquals(cache.pop(("a",)), None) - self.assertEquals(cache.get(("b",)), "B") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a",)), "A") + self.assertEqual(cache.pop(("a",)), None) + self.assertEqual(cache.get(("b",)), "B") + self.assertEqual(len(cache), 1) def test_get_set_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 3) + self.assertEqual(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 3) def test_pop_twolevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.pop(("a", "a")), "AA") - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), "AB") - self.assertEquals(cache.pop(("b", "a")), "BA") - self.assertEquals(cache.pop(("b", "a")), None) - self.assertEquals(len(cache), 1) + self.assertEqual(cache.pop(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), "AB") + self.assertEqual(cache.pop(("b", "a")), "BA") + self.assertEqual(cache.pop(("b", "a")), None) + self.assertEqual(len(cache), 1) def test_pop_mixedlevel(self): cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" - self.assertEquals(cache.get(("a", "a")), "AA") + self.assertEqual(cache.get(("a", "a")), "AA") popped = cache.pop(("a",)) - self.assertEquals(cache.get(("a", "a")), None) - self.assertEquals(cache.get(("a", "b")), None) - self.assertEquals(cache.get(("b", "a")), "BA") - self.assertEquals(len(cache), 1) + self.assertEqual(cache.get(("a", "a")), None) + self.assertEqual(cache.get(("a", "b")), None) + self.assertEqual(cache.get(("b", "a")), "BA") + self.assertEqual(len(cache), 1) - self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) def test_clear(self): cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() - self.assertEquals(len(cache), 0) + self.assertEqual(len(cache), 0) def test_contains(self): cache = TreeCache() diff --git a/tests/utils.py b/tests/utils.py index c06fc320f..ef99c72e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage().persistence - store = hs.get_datastore() + store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() diff --git a/tox.ini b/tox.ini index 32679e910..04b972e2c 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 minversion = 2.3.2 +# the tox-venv plugin makes tox use python's built-in `venv` module rather than +# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`, +# etc, and ends up being rather unreliable. +requires = tox-venv + [base] deps = python-subunit @@ -119,6 +124,9 @@ usedevelop = false deps = Automat == 0.8.0 lxml + # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on + # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed. + markupsafe < 2.1 {[base]deps} commands = @@ -158,17 +166,7 @@ commands = [testenv:check_isort] extras = lint -commands = isort -c --df --sp setup.cfg {[base]lint_targets} - -[testenv:check-newsfragment] -skip_install = true -usedevelop = false -deps = towncrier>=18.6.0rc1 -commands = - python -m towncrier.check --compare-with=origin/develop - -[testenv:check-sampleconfig] -commands = {toxinidir}/scripts-dev/generate_sample_config --check +commands = isort -c --df {[base]lint_targets} [testenv:combine] skip_install = true