Merge remote-tracking branch 'upstream/release-v1.65'

This commit is contained in:
Tulir Asokan 2022-08-09 15:00:04 +03:00
commit 18fea0e69c
104 changed files with 2364 additions and 1200 deletions

View file

@ -135,11 +135,42 @@ jobs:
/logs/**/*.log* /logs/**/*.log*
# TODO: run complement (as with twisted trunk, see #12473). complement:
if: "${{ !failure() && !cancelled() }}"
runs-on: ubuntu-latest
# open an issue if the build fails, so we know about it. strategy:
fail-fast: false
matrix:
include:
- arrangement: monolith
database: SQLite
- arrangement: monolith
database: Postgres
- arrangement: workers
database: Postgres
steps:
- name: Run actions/checkout@v2 for synapse
uses: actions/checkout@v2
with:
path: synapse
- name: Prepare Complement's Prerequisites
run: synapse/.ci/scripts/setup_complement_prerequisites.sh
- run: |
set -o pipefail
TEST_ONLY_IGNORE_POETRY_LOCKFILE=1 POSTGRES=${{ (matrix.database == 'Postgres') && 1 || '' }} WORKERS=${{ (matrix.arrangement == 'workers') && 1 || '' }} COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
# Open an issue if the build fails, so we know about it.
# Only do this if we're not experimenting with this action in a PR.
open-issue: open-issue:
if: failure() if: "failure() && github.event_name != 'push' && github.event_name != 'pull_request'"
needs: needs:
# TODO: should mypy be included here? It feels more brittle than the other two. # TODO: should mypy be included here? It feels more brittle than the other two.
- mypy - mypy

View file

@ -328,6 +328,9 @@ jobs:
- arrangement: monolith - arrangement: monolith
database: Postgres database: Postgres
- arrangement: workers
database: Postgres
steps: steps:
- name: Run actions/checkout@v2 for synapse - name: Run actions/checkout@v2 for synapse
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -343,30 +346,6 @@ jobs:
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
# XXX When complement with workers is stable, move this back into the standard
# "complement" matrix above.
#
# See https://github.com/matrix-org/synapse/issues/13161
complement-workers:
if: "${{ !failure() && !cancelled() }}"
needs: linting-done
runs-on: ubuntu-latest
steps:
- name: Run actions/checkout@v2 for synapse
uses: actions/checkout@v2
with:
path: synapse
- name: Prepare Complement's Prerequisites
run: synapse/.ci/scripts/setup_complement_prerequisites.sh
- run: |
set -o pipefail
POSTGRES=1 WORKERS=1 COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
# a job which marks all the other jobs as complete, thus allowing PRs to be merged. # a job which marks all the other jobs as complete, thus allowing PRs to be merged.
tests-done: tests-done:
if: ${{ always() }} if: ${{ always() }}

View file

@ -1,3 +1,78 @@
Synapse 1.65.0rc1 (2022-08-09)
==============================
Features
--------
- Add support for stable prefixes for [MSC2285 (private read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). ([\#13273](https://github.com/matrix-org/synapse/issues/13273))
- Add new unstable error codes `ORG.MATRIX.MSC3848.ALREADY_JOINED`, `ORG.MATRIX.MSC3848.NOT_JOINED`, and `ORG.MATRIX.MSC3848.INSUFFICIENT_POWER` described in [MSC3848](https://github.com/matrix-org/matrix-spec-proposals/pull/3848). ([\#13343](https://github.com/matrix-org/synapse/issues/13343))
- Use stable prefixes for [MSC3827](https://github.com/matrix-org/matrix-spec-proposals/pull/3827). ([\#13370](https://github.com/matrix-org/synapse/issues/13370))
- Add a new module API method to translate a room alias into a room ID. ([\#13428](https://github.com/matrix-org/synapse/issues/13428))
- Add a new module API method to create a room. ([\#13429](https://github.com/matrix-org/synapse/issues/13429))
- Add remote join capability to the module API's `update_room_membership` method (in a backwards compatible manner). ([\#13441](https://github.com/matrix-org/synapse/issues/13441))
Bugfixes
--------
- Update the version of the LDAP3 auth provider module included in the `matrixdotorg/synapse` DockerHub images and the Debian packages
- Fix a bug introduced in Synapse v1.41.0 where the `/hierarchy` API returned non-standard information (a `room_id` field under each entry in `children_state`). ([\#13365](https://github.com/matrix-org/synapse/issues/13365))
- Fix a bug introduced in Synapse 0.24.0 that would respond with the wrong error status code to `/joined_members` requests when the requester is not a current member of the room. Contributed by @andrewdoh. ([\#13374](https://github.com/matrix-org/synapse/issues/13374))
- Fix bug in handling of typing events for appservices. Contributed by Nick @ Beeper (@fizzadar). ([\#13392](https://github.com/matrix-org/synapse/issues/13392))
- Fix a bug introduced in Synapse 1.57.0 where rooms listed in `exclude_rooms_from_sync` in the configuration file would not be properly excluded from incremental syncs. ([\#13408](https://github.com/matrix-org/synapse/issues/13408))
- Fix a bug in the experimental faster-room-joins support which could cause it to get stuck in an infinite loop. ([\#13353](https://github.com/matrix-org/synapse/issues/13353))
- Faster room joins: fix a bug which caused rejected events to become un-rejected during state syncing. ([\#13413](https://github.com/matrix-org/synapse/issues/13413))
- Faster room joins: fix error when running out of servers to sync partial state with, so that Synapse raises the intended error instead. ([\#13432](https://github.com/matrix-org/synapse/issues/13432))
hosted on packages.matrix.org to 0.2.2. This version fixes a regression in the module. ([\#13470](https://github.com/matrix-org/synapse/issues/13470))
Updates to the Docker image
---------------------------
- Make Docker images build on armv7 by installing cryptography dependencies in the 'requirements' stage. Contributed by Jasper Spaans. ([\#13372](https://github.com/matrix-org/synapse/issues/13372))
Improved Documentation
----------------------
- Update the 'registration tokens' page to acknowledge that the relevant MSC was merged into version 1.2 of the Matrix specification. Contributed by @moan0s. ([\#11897](https://github.com/matrix-org/synapse/issues/11897))
- Document which HTTP resources support gzip compression. ([\#13221](https://github.com/matrix-org/synapse/issues/13221))
- Add steps describing how to elevate an existing user to administrator by manipulating the database. ([\#13230](https://github.com/matrix-org/synapse/issues/13230))
- Fix wrong headline for `url_preview_accept_language` in documentation. ([\#13437](https://github.com/matrix-org/synapse/issues/13437))
- Remove redundant 'Contents' section from the Configuration Manual. Contributed by @dklimpel. ([\#13438](https://github.com/matrix-org/synapse/issues/13438))
- Update documentation for config setting `macaroon_secret_key`. ([\#13443](https://github.com/matrix-org/synapse/issues/13443))
- Update outdated information on `sso_mapping_providers` documentation. ([\#13449](https://github.com/matrix-org/synapse/issues/13449))
- Fix example code in module documentation of `password_auth_provider_callbacks`. ([\#13450](https://github.com/matrix-org/synapse/issues/13450))
- Make the configuration for the cache clearer. ([\#13481](https://github.com/matrix-org/synapse/issues/13481))
Internal Changes
----------------
- Extend the release script to automatically push a new SyTest branch, rather than having that be a manual process. ([\#12978](https://github.com/matrix-org/synapse/issues/12978))
- Make minor clarifications to the error messages given when we fail to join a room via any server. ([\#13160](https://github.com/matrix-org/synapse/issues/13160))
- Enable Complement CI tests in the 'latest deps' test run. ([\#13213](https://github.com/matrix-org/synapse/issues/13213))
- Fix long-standing bugged logic which was never hit in `get_pdu` asking every remote destination even after it finds an event. ([\#13346](https://github.com/matrix-org/synapse/issues/13346))
- Faster room joins: avoid blocking when pulling events with partially missing prev events. ([\#13355](https://github.com/matrix-org/synapse/issues/13355))
- Instrument `/messages` for understandable traces in Jaeger. ([\#13368](https://github.com/matrix-org/synapse/issues/13368))
- Remove an unused argument to `get_relations_for_event`. ([\#13383](https://github.com/matrix-org/synapse/issues/13383))
- Add a `merge-back` command to the release script, which automates merging the correct branches after a release. ([\#13393](https://github.com/matrix-org/synapse/issues/13393))
- Adding missing type hints to tests. ([\#13397](https://github.com/matrix-org/synapse/issues/13397))
- Faster Room Joins: don't leave a stuck room partial state flag if the join fails. ([\#13403](https://github.com/matrix-org/synapse/issues/13403))
- Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead. ([\#13404](https://github.com/matrix-org/synapse/issues/13404), [\#13431](https://github.com/matrix-org/synapse/issues/13431))
- Faster Room Joins: prevent Synapse from answering federated join requests for a room which it has not fully joined yet. ([\#13416](https://github.com/matrix-org/synapse/issues/13416))
- Re-enable running Complement tests against Synapse with workers. ([\#13420](https://github.com/matrix-org/synapse/issues/13420))
- Prevent unnecessary lookups to any external `get_event` cache. Contributed by Nick @ Beeper (@fizzadar). ([\#13435](https://github.com/matrix-org/synapse/issues/13435))
- Add some tracing to give more insight into local room joins. ([\#13439](https://github.com/matrix-org/synapse/issues/13439))
- Rename class `RateLimitConfig` to `RatelimitSettings` and `FederationRateLimitConfig` to `FederationRatelimitSettings`. ([\#13442](https://github.com/matrix-org/synapse/issues/13442))
- Add some comments about how event push actions are stored. ([\#13445](https://github.com/matrix-org/synapse/issues/13445), [\#13455](https://github.com/matrix-org/synapse/issues/13455))
- Improve rebuild speed for the "synapse-workers" docker image. ([\#13447](https://github.com/matrix-org/synapse/issues/13447))
- Fix `@tag_args` being off-by-one with the arguments when tagging a span (tracing). ([\#13452](https://github.com/matrix-org/synapse/issues/13452))
- Update type of `EventContext.rejected`. ([\#13460](https://github.com/matrix-org/synapse/issues/13460))
- Use literals in place of `HTTPStatus` constants in tests. ([\#13463](https://github.com/matrix-org/synapse/issues/13463), [\#13469](https://github.com/matrix-org/synapse/issues/13469))
- Correct a misnamed argument in state res v2 internals. ([\#13467](https://github.com/matrix-org/synapse/issues/13467))
Synapse 1.64.0 (2022-08-02) Synapse 1.64.0 (2022-08-02)
=========================== ===========================

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.65.0~rc1) stable; urgency=medium
* New Synapse release 1.65.0rc1.
-- Synapse Packaging team <packages@matrix.org> Tue, 09 Aug 2022 11:39:29 +0100
matrix-synapse-py3 (1.64.0) stable; urgency=medium matrix-synapse-py3 (1.64.0) stable; urgency=medium
* New Synapse release 1.64.0. * New Synapse release 1.64.0.

View file

@ -40,7 +40,8 @@ FROM docker.io/python:${PYTHON_VERSION}-slim as requirements
RUN \ RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq git \ apt-get update -qq && apt-get install -yqq \
build-essential cargo git libffi-dev libssl-dev \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# We install poetry in its own build stage to avoid its dependencies conflicting with # We install poetry in its own build stage to avoid its dependencies conflicting with
@ -68,7 +69,18 @@ COPY pyproject.toml poetry.lock /synapse/
# reason, such as when a git repository is used directly as a dependency. # reason, such as when a git repository is used directly as a dependency.
ARG TEST_ONLY_SKIP_DEP_HASH_VERIFICATION ARG TEST_ONLY_SKIP_DEP_HASH_VERIFICATION
RUN /root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes} # If specified, we won't use the Poetry lockfile.
# Instead, we'll just install what a regular `pip install` would from PyPI.
ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
# Export the dependencies, but only if we're actually going to use the Poetry lockfile.
# Otherwise, just create an empty requirements file so that the Dockerfile can
# proceed.
RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
/root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
else \
touch /synapse/requirements.txt; \
fi
### ###
### Stage 1: builder ### Stage 1: builder
@ -108,8 +120,17 @@ COPY synapse /synapse/synapse/
# ... and what we need to `pip install`. # ... and what we need to `pip install`.
COPY pyproject.toml README.rst /synapse/ COPY pyproject.toml README.rst /synapse/
# Repeat of earlier build argument declaration, as this is a new build stage.
ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
# Install the synapse package itself. # Install the synapse package itself.
RUN pip install --prefix="/install" --no-deps --no-warn-script-location /synapse # If we have populated requirements.txt, we don't install any dependencies
# as we should already have those from the previous `pip install` step.
RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
else \
pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
fi
### ###
### Stage 2: runtime ### Stage 2: runtime

View file

@ -1,39 +1,62 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Inherit from the official Synapse docker image
ARG SYNAPSE_VERSION=latest ARG SYNAPSE_VERSION=latest
# first of all, we create a base image with an nginx which we can copy into the
# target image. For repeated rebuilds, this is much faster than apt installing
# each time.
FROM debian:bullseye-slim AS deps_base
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && \
DEBIAN_FRONTEND=noninteractive apt-get install -yqq --no-install-recommends \
redis-server nginx-light
# Similarly, a base to copy the redis server from.
#
# The redis docker image has fewer dynamic libraries than the debian package,
# which makes it much easier to copy (but we need to make sure we use an image
# based on the same debian version as the synapse image, to make sure we get
# the expected version of libc.
FROM redis:6-bullseye AS redis_base
# now build the final image, based on the the regular Synapse docker image
FROM matrixdotorg/synapse:$SYNAPSE_VERSION FROM matrixdotorg/synapse:$SYNAPSE_VERSION
# Install deps # Install supervisord with pip instead of apt, to avoid installing a second
RUN \ # copy of python.
--mount=type=cache,target=/var/cache/apt,sharing=locked \ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/var/lib/apt,sharing=locked \ pip install supervisor~=4.2
apt-get update -qq && \ RUN mkdir -p /etc/supervisor/conf.d
DEBIAN_FRONTEND=noninteractive apt-get install -yqq --no-install-recommends \
redis-server nginx-light
# Install supervisord with pip instead of apt, to avoid installing a second # Copy over redis and nginx
# copy of python. COPY --from=redis_base /usr/local/bin/redis-server /usr/local/bin
RUN --mount=type=cache,target=/root/.cache/pip \
pip install supervisor~=4.2
# Disable the default nginx sites COPY --from=deps_base /usr/sbin/nginx /usr/sbin
RUN rm /etc/nginx/sites-enabled/default COPY --from=deps_base /usr/share/nginx /usr/share/nginx
COPY --from=deps_base /usr/lib/nginx /usr/lib/nginx
COPY --from=deps_base /etc/nginx /etc/nginx
RUN rm /etc/nginx/sites-enabled/default
RUN mkdir /var/log/nginx /var/lib/nginx
RUN chown www-data /var/log/nginx /var/lib/nginx
# Copy Synapse worker, nginx and supervisord configuration template files # Copy Synapse worker, nginx and supervisord configuration template files
COPY ./docker/conf-workers/* /conf/ COPY ./docker/conf-workers/* /conf/
# Copy a script to prefix log lines with the supervisor program name # Copy a script to prefix log lines with the supervisor program name
COPY ./docker/prefix-log /usr/local/bin/ COPY ./docker/prefix-log /usr/local/bin/
# Expose nginx listener port # Expose nginx listener port
EXPOSE 8080/tcp EXPOSE 8080/tcp
# A script to read environment variables and create the necessary # A script to read environment variables and create the necessary
# files to run the desired worker configuration. Will start supervisord. # files to run the desired worker configuration. Will start supervisord.
COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py
ENTRYPOINT ["/configure_workers_and_start.py"] ENTRYPOINT ["/configure_workers_and_start.py"]
# Replace the healthcheck with one which checks *all* the workers. The script # Replace the healthcheck with one which checks *all* the workers. The script
# is generated by configure_workers_and_start.py. # is generated by configure_workers_and_start.py.
HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \ HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
CMD /bin/sh /healthcheck.sh CMD /bin/sh /healthcheck.sh

View file

@ -19,7 +19,7 @@ username=www-data
autorestart=true autorestart=true
[program:redis] [program:redis]
command=/usr/local/bin/prefix-log /usr/bin/redis-server /etc/redis/redis.conf --daemonize no command=/usr/local/bin/prefix-log /usr/local/bin/redis-server
priority=1 priority=1
stdout_logfile=/dev/stdout stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0 stdout_logfile_maxbytes=0

View file

@ -263,7 +263,7 @@ class MyAuthProvider:
return None return None
if self.credentials.get(username) == login_dict.get("my_field"): if self.credentials.get(username) == login_dict.get("my_field"):
return self.api.get_qualified_user_id(username) return (self.api.get_qualified_user_id(username), None)
async def check_pass( async def check_pass(
self, self,
@ -280,5 +280,5 @@ class MyAuthProvider:
return None return None
if self.credentials.get(username) == login_dict.get("password"): if self.credentials.get(username) == login_dict.get("password"):
return self.api.get_qualified_user_id(username) return (self.api.get_qualified_user_id(username), None)
``` ```

View file

@ -22,7 +22,7 @@ choose their own username.
In the first case - where users are automatically allocated a Matrix ID - it is In the first case - where users are automatically allocated a Matrix ID - it is
the responsibility of the mapping provider to normalise the SSO attributes and the responsibility of the mapping provider to normalise the SSO attributes and
map them to a valid Matrix ID. The [specification for Matrix map them to a valid Matrix ID. The [specification for Matrix
IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some IDs](https://spec.matrix.org/latest/appendices/#user-identifiers) has some
information about what is considered valid. information about what is considered valid.
If the mapping provider does not assign a Matrix ID, then Synapse will If the mapping provider does not assign a Matrix ID, then Synapse will
@ -37,9 +37,10 @@ as Synapse). The Synapse config is then modified to point to the mapping provide
## OpenID Mapping Providers ## OpenID Mapping Providers
The OpenID mapping provider can be customized by editing the The OpenID mapping provider can be customized by editing the
`oidc_config.user_mapping_provider.module` config option. [`oidc_providers.user_mapping_provider.module`](usage/configuration/config_documentation.md#oidc_providers)
config option.
`oidc_config.user_mapping_provider.config` allows you to provide custom `oidc_providers.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should user mapping provider built in to Synapse. If using a custom module, you should
@ -58,7 +59,7 @@ A custom mapping provider must specify the following methods:
- This method should have the `@staticmethod` decoration. - This method should have the `@staticmethod` decoration.
- Arguments: - Arguments:
- `config` - A `dict` representing the parsed content of the - `config` - A `dict` representing the parsed content of the
`oidc_config.user_mapping_provider.config` homeserver config option. `oidc_providers.user_mapping_provider.config` homeserver config option.
Runs on homeserver startup. Providers should extract and validate Runs on homeserver startup. Providers should extract and validate
any option values they need here. any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's - Whatever is returned will be passed back to the user mapping provider module's
@ -102,7 +103,7 @@ A custom mapping provider must specify the following methods:
will be returned as part of the response during a successful login. will be returned as part of the response during a successful login.
Note that care should be taken to not overwrite any of the parameters Note that care should be taken to not overwrite any of the parameters
usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). usually returned as part of the [login response](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3login).
### Default OpenID Mapping Provider ### Default OpenID Mapping Provider
@ -113,7 +114,8 @@ specified in the config. It is located at
## SAML Mapping Providers ## SAML Mapping Providers
The SAML mapping provider can be customized by editing the The SAML mapping provider can be customized by editing the
`saml2_config.user_mapping_provider.module` config option. [`saml2_config.user_mapping_provider.module`](docs/usage/configuration/config_documentation.md#saml2_config)
config option.
`saml2_config.user_mapping_provider.config` allows you to provide custom `saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for configuration options to the module. Check with the module's documentation for

View file

@ -5,8 +5,9 @@
Many of the API calls in the admin api will require an `access_token` for a Many of the API calls in the admin api will require an `access_token` for a
server admin. (Note that a server admin is distinct from a room admin.) server admin. (Note that a server admin is distinct from a room admin.)
A user can be marked as a server admin by updating the database directly, e.g.: An existing user can be marked as a server admin by updating the database directly.
Check your [database settings](config_documentation.md#database) in the configuration file, connect to the correct database using either `psql [database name]` (if using PostgreSQL) or `sqlite3 path/to/your/database.db` (if using SQLite) and elevate the user `@foo:bar.com` to administrator.
```sql ```sql
UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
``` ```

View file

@ -2,11 +2,11 @@
This API allows you to manage tokens which can be used to authenticate This API allows you to manage tokens which can be used to authenticate
registration requests, as proposed in registration requests, as proposed in
[MSC3231](https://github.com/matrix-org/matrix-doc/blob/main/proposals/3231-token-authenticated-registration.md). [MSC3231](https://github.com/matrix-org/matrix-doc/blob/main/proposals/3231-token-authenticated-registration.md)
and stabilised in version 1.2 of the Matrix specification.
To use it, you will need to enable the `registration_requires_token` config To use it, you will need to enable the `registration_requires_token` config
option, and authenticate by providing an `access_token` for a server admin: option, and authenticate by providing an `access_token` for a server admin:
see [Admin API](../../usage/administration/admin_api). see [Admin API](../admin_api).
Note that this API is still experimental; not all clients may support it yet.
## Registration token objects ## Registration token objects

View file

@ -72,49 +72,6 @@ apply if you want your config file to be read properly. A few helpful things to
In addition, each setting has an example of its usage, with the proper indentation In addition, each setting has an example of its usage, with the proper indentation
shown. shown.
## Contents
[Modules](#modules)
[Server](#server)
[Homeserver Blocking](#homeserver-blocking)
[TLS](#tls)
[Federation](#federation)
[Caching](#caching)
[Database](#database)
[Logging](#logging)
[Ratelimiting](#ratelimiting)
[Media Store](#media-store)
[Captcha](#captcha)
[TURN](#turn)
[Registration](#registration)
[API Configuration](#api-configuration)
[Signing Keys](#signing-keys)
[Single Sign On Integration](#single-sign-on-integration)
[Push](#push)
[Rooms](#rooms)
[Opentracing](#opentracing)
[Workers](#workers)
[Background Updates](#background-updates)
## Modules ## Modules
Server admins can expand Synapse's functionality with external modules. Server admins can expand Synapse's functionality with external modules.
@ -486,7 +443,8 @@ Sub-options for each listener include:
* `names`: a list of names of HTTP resources. See below for a list of valid resource names. * `names`: a list of names of HTTP resources. See below for a list of valid resource names.
* `compress`: set to true to enable HTTP compression for this resource. * `compress`: set to true to enable gzip compression on HTTP bodies for this resource. This is currently only supported with the
`client`, `consent` and `metrics` resources.
* `additional_resources`: Only valid for an 'http' listener. A map of * `additional_resources`: Only valid for an 'http' listener. A map of
additional endpoints which should be loaded via dynamic modules. additional endpoints which should be loaded via dynamic modules.
@ -1098,26 +1056,26 @@ allow_device_name_lookup_over_federation: true
--- ---
## Caching ## ## Caching ##
Options related to caching Options related to caching.
--- ---
### `event_cache_size` ### `event_cache_size`
The number of events to cache in memory. Not affected by The number of events to cache in memory. Not affected by
`caches.global_factor`. Defaults to 10K. `caches.global_factor` and is not part of the `caches` section. Defaults to 10K.
Example configuration: Example configuration:
```yaml ```yaml
event_cache_size: 15K event_cache_size: 15K
``` ```
--- ---
### `cache` and associated values ### `caches` and associated values
A cache 'factor' is a multiplier that can be applied to each of A cache 'factor' is a multiplier that can be applied to each of
Synapse's caches in order to increase or decrease the maximum Synapse's caches in order to increase or decrease the maximum
number of entries that can be stored. number of entries that can be stored.
Caching can be configured through the following sub-options: `caches` can be configured through the following sub-options:
* `global_factor`: Controls the global cache factor, which is the default cache factor * `global_factor`: Controls the global cache factor, which is the default cache factor
for all caches if a specific factor for that cache is not otherwise for all caches if a specific factor for that cache is not otherwise
@ -1179,6 +1137,7 @@ Caching can be configured through the following sub-options:
Example configuration: Example configuration:
```yaml ```yaml
event_cache_size: 15K
caches: caches:
global_factor: 1.0 global_factor: 1.0
per_cache_factors: per_cache_factors:
@ -1858,7 +1817,7 @@ Example configuration:
max_spider_size: 8M max_spider_size: 8M
``` ```
--- ---
### `url_preview_language` ### `url_preview_accept_language`
A list of values for the Accept-Language HTTP header used when A list of values for the Accept-Language HTTP header used when
downloading webpages during URL preview generation. This allows downloading webpages during URL preview generation. This allows
@ -2537,9 +2496,13 @@ track_appservice_user_ips: true
--- ---
### `macaroon_secret_key` ### `macaroon_secret_key`
A secret which is used to sign access tokens. If none is specified, A secret which is used to sign
the `registration_shared_secret` is used, if one is given; otherwise, - access token for guest users,
a secret key is derived from the signing key. - short-term login token used during SSO logins (OIDC or SAML2) and
- token used for unsubscribing from email notifications.
If none is specified, the `registration_shared_secret` is used, if one is given;
otherwise, a secret key is derived from the signing key.
Example configuration: Example configuration:
```yaml ```yaml

24
poetry.lock generated
View file

@ -177,7 +177,7 @@ optional = false
python-versions = "*" python-versions = "*"
[package.extras] [package.extras]
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] test = ["hypothesis (==3.55.3)", "flake8 (==3.7.8)"]
[[package]] [[package]]
name = "constantly" name = "constantly"
@ -435,8 +435,8 @@ optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[package.extras] [package.extras]
test = ["pytest", "pytest-trio", "pytest-asyncio", "testpath", "trio", "async-timeout"] trio = ["async-generator", "trio"]
trio = ["trio", "async-generator"] test = ["async-timeout", "trio", "testpath", "pytest-asyncio", "pytest-trio", "pytest"]
[[package]] [[package]]
name = "jinja2" name = "jinja2"
@ -535,12 +535,12 @@ attrs = "*"
importlib-metadata = {version = ">=1.4", markers = "python_version < \"3.8\""} importlib-metadata = {version = ">=1.4", markers = "python_version < \"3.8\""}
[package.extras] [package.extras]
dev = ["tox", "twisted", "aiounittest", "mypy (==0.910)", "black (==22.3.0)", "flake8 (==4.0.1)", "isort (==5.9.3)", "build (==0.8.0)", "twine (==4.0.1)"] test = ["aiounittest", "twisted", "tox"]
test = ["tox", "twisted", "aiounittest"] dev = ["twine (==4.0.1)", "build (==0.8.0)", "isort (==5.9.3)", "flake8 (==4.0.1)", "black (==22.3.0)", "mypy (==0.910)", "aiounittest", "twisted", "tox"]
[[package]] [[package]]
name = "matrix-synapse-ldap3" name = "matrix-synapse-ldap3"
version = "0.2.1" version = "0.2.2"
description = "An LDAP3 auth provider for Synapse" description = "An LDAP3 auth provider for Synapse"
category = "main" category = "main"
optional = true optional = true
@ -552,7 +552,7 @@ service-identity = "*"
Twisted = ">=15.1.0" Twisted = ">=15.1.0"
[package.extras] [package.extras]
dev = ["matrix-synapse", "tox", "ldaptor", "mypy (==0.910)", "types-setuptools", "black (==22.3.0)", "flake8 (==4.0.1)", "isort (==5.9.3)"] dev = ["isort (==5.9.3)", "flake8 (==4.0.1)", "black (==22.3.0)", "types-setuptools", "mypy (==0.910)", "ldaptor", "tox", "matrix-synapse"]
[[package]] [[package]]
name = "mccabe" name = "mccabe"
@ -820,10 +820,10 @@ optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[package.extras] [package.extras]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
docs = ["zope.interface", "sphinx-rtd-theme", "sphinx"]
dev = ["pre-commit", "mypy", "coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)", "cryptography (>=3.3.1)", "zope.interface", "sphinx-rtd-theme", "sphinx"]
crypto = ["cryptography (>=3.3.1)"] crypto = ["cryptography (>=3.3.1)"]
dev = ["sphinx", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)", "pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)", "mypy", "pre-commit"]
docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"]
tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"]
[[package]] [[package]]
name = "pymacaroons" name = "pymacaroons"
@ -2055,8 +2055,8 @@ matrix-common = [
{file = "matrix_common-1.2.1.tar.gz", hash = "sha256:a99dcf02a6bd95b24a5a61b354888a2ac92bf2b4b839c727b8dd9da2cdfa3853"}, {file = "matrix_common-1.2.1.tar.gz", hash = "sha256:a99dcf02a6bd95b24a5a61b354888a2ac92bf2b4b839c727b8dd9da2cdfa3853"},
] ]
matrix-synapse-ldap3 = [ matrix-synapse-ldap3 = [
{file = "matrix-synapse-ldap3-0.2.1.tar.gz", hash = "sha256:bfb4390f4a262ffb0d6f057ff3aeb1e46d4e52ff420a064d795fb4f555f00285"}, {file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"},
{file = "matrix_synapse_ldap3-0.2.1-py3-none-any.whl", hash = "sha256:1b3310a60f1d06466f35905a269b6df95747fd1305f2b7fe638f373963b2aa2c"}, {file = "matrix_synapse_ldap3-0.2.2-py3-none-any.whl", hash = "sha256:66ee4c85d7952c6c27fd04c09cdfdf4847b8e8b7d6a7ada6ba1100013bda060f"},
] ]
mccabe = [ mccabe = [
{file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},

View file

@ -54,7 +54,7 @@ skip_gitignore = true
[tool.poetry] [tool.poetry]
name = "matrix-synapse" name = "matrix-synapse"
version = "1.64.0" version = "1.65.0rc1"
description = "Homeserver for the Matrix decentralised comms protocol" description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"] authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0" license = "Apache-2.0"

View file

@ -101,6 +101,7 @@ if [ -z "$skip_docker_build" ]; then
echo_if_github "::group::Build Docker image: matrixdotorg/synapse" echo_if_github "::group::Build Docker image: matrixdotorg/synapse"
docker build -t matrixdotorg/synapse \ docker build -t matrixdotorg/synapse \
--build-arg TEST_ONLY_SKIP_DEP_HASH_VERIFICATION \ --build-arg TEST_ONLY_SKIP_DEP_HASH_VERIFICATION \
--build-arg TEST_ONLY_IGNORE_POETRY_LOCKFILE \
-f "docker/Dockerfile" . -f "docker/Dockerfile" .
echo_if_github "::endgroup::" echo_if_github "::endgroup::"

View file

@ -32,6 +32,7 @@ import click
import commonmark import commonmark
import git import git
from click.exceptions import ClickException from click.exceptions import ClickException
from git import GitCommandError, Repo
from github import Github from github import Github
from packaging import version from packaging import version
@ -55,9 +56,12 @@ def run_until_successful(
def cli() -> None: def cli() -> None:
"""An interactive script to walk through the parts of creating a release. """An interactive script to walk through the parts of creating a release.
Requires the dev dependencies be installed, which can be done via: Requirements:
- The dev dependencies be installed, which can be done via:
pip install -e .[dev] pip install -e .[dev]
- A checkout of the sytest repository at ../sytest
Then to use: Then to use:
@ -75,6 +79,8 @@ def cli() -> None:
# Optional: generate some nice links for the announcement # Optional: generate some nice links for the announcement
./scripts-dev/release.py merge-back
./scripts-dev/release.py announce ./scripts-dev/release.py announce
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
@ -89,10 +95,12 @@ def prepare() -> None:
""" """
# Make sure we're in a git repo. # Make sure we're in a git repo.
repo = get_repo_and_check_clean_checkout() synapse_repo = get_repo_and_check_clean_checkout()
sytest_repo = get_repo_and_check_clean_checkout("../sytest", "sytest")
click.secho("Updating git repo...") click.secho("Updating Synapse and Sytest git repos...")
repo.remote().fetch() synapse_repo.remote().fetch()
sytest_repo.remote().fetch()
# Get the current version and AST from root Synapse module. # Get the current version and AST from root Synapse module.
current_version = get_package_version() current_version = get_package_version()
@ -166,12 +174,12 @@ def prepare() -> None:
assert not parsed_new_version.is_postrelease assert not parsed_new_version.is_postrelease
release_branch_name = get_release_branch_name(parsed_new_version) release_branch_name = get_release_branch_name(parsed_new_version)
release_branch = find_ref(repo, release_branch_name) release_branch = find_ref(synapse_repo, release_branch_name)
if release_branch: if release_branch:
if release_branch.is_remote(): if release_branch.is_remote():
# If the release branch only exists on the remote we check it out # If the release branch only exists on the remote we check it out
# locally. # locally.
repo.git.checkout(release_branch_name) synapse_repo.git.checkout(release_branch_name)
else: else:
# If a branch doesn't exist we create one. We ask which one branch it # If a branch doesn't exist we create one. We ask which one branch it
# should be based off, defaulting to sensible values depending on the # should be based off, defaulting to sensible values depending on the
@ -187,25 +195,34 @@ def prepare() -> None:
"Which branch should the release be based on?", default=default "Which branch should the release be based on?", default=default
) )
base_branch = find_ref(repo, branch_name) for repo_name, repo in {"synapse": synapse_repo, "sytest": sytest_repo}.items():
if not base_branch: base_branch = find_ref(repo, branch_name)
print(f"Could not find base branch {branch_name}!") if not base_branch:
click.get_current_context().abort() print(f"Could not find base branch {branch_name} for {repo_name}!")
click.get_current_context().abort()
# Check out the base branch and ensure it's up to date # Check out the base branch and ensure it's up to date
repo.head.set_reference(base_branch, "check out the base branch") repo.head.set_reference(
repo.head.reset(index=True, working_tree=True) base_branch, f"check out the base branch for {repo_name}"
if not base_branch.is_remote(): )
update_branch(repo) repo.head.reset(index=True, working_tree=True)
if not base_branch.is_remote():
update_branch(repo)
# Create the new release branch # Create the new release branch
# Type ignore will no longer be needed after GitPython 3.1.28. # Type ignore will no longer be needed after GitPython 3.1.28.
# See https://github.com/gitpython-developers/GitPython/pull/1419 # See https://github.com/gitpython-developers/GitPython/pull/1419
repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type]
# Special-case SyTest: we don't actually prepare any files so we may
# as well push it now (and only when we create a release branch;
# not on subsequent RCs or full releases).
if click.confirm("Push new SyTest branch?", default=True):
sytest_repo.git.push("-u", sytest_repo.remote().name, release_branch_name)
# Switch to the release branch and ensure it's up to date. # Switch to the release branch and ensure it's up to date.
repo.git.checkout(release_branch_name) synapse_repo.git.checkout(release_branch_name)
update_branch(repo) update_branch(synapse_repo)
# Update the version specified in pyproject.toml. # Update the version specified in pyproject.toml.
subprocess.check_output(["poetry", "version", new_version]) subprocess.check_output(["poetry", "version", new_version])
@ -230,15 +247,15 @@ def prepare() -> None:
run_until_successful('dch -M -r -D stable ""', shell=True) run_until_successful('dch -M -r -D stable ""', shell=True)
# Show the user the changes and ask if they want to edit the change log. # Show the user the changes and ask if they want to edit the change log.
repo.git.add("-u") synapse_repo.git.add("-u")
subprocess.run("git diff --cached", shell=True) subprocess.run("git diff --cached", shell=True)
if click.confirm("Edit changelog?", default=False): if click.confirm("Edit changelog?", default=False):
click.edit(filename="CHANGES.md") click.edit(filename="CHANGES.md")
# Commit the changes. # Commit the changes.
repo.git.add("-u") synapse_repo.git.add("-u")
repo.git.commit("-m", new_version) synapse_repo.git.commit("-m", new_version)
# We give the option to bail here in case the user wants to make sure things # We give the option to bail here in case the user wants to make sure things
# are OK before pushing. # are OK before pushing.
@ -246,17 +263,21 @@ def prepare() -> None:
print("") print("")
print("Run when ready to push:") print("Run when ready to push:")
print("") print("")
print(f"\tgit push -u {repo.remote().name} {repo.active_branch.name}") print(
f"\tgit push -u {synapse_repo.remote().name} {synapse_repo.active_branch.name}"
)
print("") print("")
sys.exit(0) sys.exit(0)
# Otherwise, push and open the changelog in the browser. # Otherwise, push and open the changelog in the browser.
repo.git.push("-u", repo.remote().name, repo.active_branch.name) synapse_repo.git.push(
"-u", synapse_repo.remote().name, synapse_repo.active_branch.name
)
print("Opening the changelog in your browser...") print("Opening the changelog in your browser...")
print("Please ask others to give it a check.") print("Please ask others to give it a check.")
click.launch( click.launch(
f"https://github.com/matrix-org/synapse/blob/{repo.active_branch.name}/CHANGES.md" f"https://github.com/matrix-org/synapse/blob/{synapse_repo.active_branch.name}/CHANGES.md"
) )
@ -423,6 +444,79 @@ def upload() -> None:
) )
def _merge_into(repo: Repo, source: str, target: str) -> None:
"""
Merges branch `source` into branch `target`.
Pulls both before merging and pushes the result.
"""
# Update our branches and switch to the target branch
for branch in [source, target]:
click.echo(f"Switching to {branch} and pulling...")
repo.heads[branch].checkout()
# Pull so we're up to date
repo.remote().pull()
assert repo.active_branch.name == target
try:
# TODO This seemed easier than using GitPython directly
click.echo(f"Merging {source}...")
repo.git.merge(source)
except GitCommandError as exc:
# If a merge conflict occurs, give some context and try to
# make it easy to abort if necessary.
click.echo(exc)
if not click.confirm(
f"Likely merge conflict whilst merging ({source}{target}). "
f"Have you resolved it?"
):
repo.git.merge("--abort")
return
# Push result.
click.echo("Pushing...")
repo.remote().push()
@cli.command()
def merge_back() -> None:
"""Merge the release branch back into the appropriate branches.
All branches will be automatically pulled from the remote and the results
will be pushed to the remote."""
synapse_repo = get_repo_and_check_clean_checkout()
branch_name = synapse_repo.active_branch.name
if not branch_name.startswith("release-v"):
raise RuntimeError("Not on a release branch. This does not seem sensible.")
# Pull so we're up to date
synapse_repo.remote().pull()
current_version = get_package_version()
if current_version.is_prerelease:
# Release candidate
if click.confirm(f"Merge {branch_name} → develop?", default=True):
_merge_into(synapse_repo, branch_name, "develop")
else:
# Full release
sytest_repo = get_repo_and_check_clean_checkout("../sytest", "sytest")
if click.confirm(f"Merge {branch_name} → master?", default=True):
_merge_into(synapse_repo, branch_name, "master")
if click.confirm("Merge master → develop?", default=True):
_merge_into(synapse_repo, "master", "develop")
if click.confirm(f"On SyTest, merge {branch_name} → master?", default=True):
_merge_into(sytest_repo, branch_name, "master")
if click.confirm("On SyTest, merge master → develop?", default=True):
_merge_into(sytest_repo, "master", "develop")
@cli.command() @cli.command()
def announce() -> None: def announce() -> None:
"""Generate markdown to announce the release.""" """Generate markdown to announce the release."""
@ -469,14 +563,18 @@ def get_release_branch_name(version_number: version.Version) -> str:
return f"release-v{version_number.major}.{version_number.minor}" return f"release-v{version_number.major}.{version_number.minor}"
def get_repo_and_check_clean_checkout() -> git.Repo: def get_repo_and_check_clean_checkout(
path: str = ".", name: str = "synapse"
) -> git.Repo:
"""Get the project repo and check it's not got any uncommitted changes.""" """Get the project repo and check it's not got any uncommitted changes."""
try: try:
repo = git.Repo() repo = git.Repo(path=path)
except git.InvalidGitRepositoryError: except git.InvalidGitRepositoryError:
raise click.ClickException("Not in Synapse repo.") raise click.ClickException(
f"{path} is not a git repository (expecting a {name} repository)."
)
if repo.is_dirty(): if repo.is_dirty():
raise click.ClickException("Uncommitted changes exist.") raise click.ClickException(f"Uncommitted changes exist in {path}.")
return repo return repo

View file

@ -26,11 +26,17 @@ from synapse.api.errors import (
Codes, Codes,
InvalidClientTokenError, InvalidClientTokenError,
MissingClientTokenError, MissingClientTokenError,
UnstableSpecAuthError,
) )
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.logging.opentracing import (
active_span,
force_tracing,
start_active_span,
trace,
)
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
@ -106,8 +112,11 @@ class Auth:
forgot = await self.store.did_forget(user_id, room_id) forgot = await self.store.did_forget(user_id, room_id)
if not forgot: if not forgot:
return membership, member_event_id return membership, member_event_id
raise UnstableSpecAuthError(
raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) 403,
"User %s not in room %s" % (user_id, room_id),
errcode=Codes.NOT_JOINED,
)
async def get_user_by_req( async def get_user_by_req(
self, self,
@ -564,6 +573,7 @@ class Auth:
return query_params[0].decode("ascii") return query_params[0].decode("ascii")
@trace
async def check_user_in_room_or_world_readable( async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False self, room_id: str, user_id: str, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]: ) -> Tuple[str, Optional[str]]:
@ -601,8 +611,9 @@ class Auth:
== HistoryVisibility.WORLD_READABLE == HistoryVisibility.WORLD_READABLE
): ):
return Membership.JOIN, None return Membership.JOIN, None
raise AuthError( raise UnstableSpecAuthError(
403, 403,
"User %s not in room %s, and room previews are disabled" "User %s not in room %s, and room previews are disabled"
% (user_id, room_id), % (user_id, room_id),
errcode=Codes.NOT_JOINED,
) )

View file

@ -257,7 +257,8 @@ class GuestAccess:
class ReceiptTypes: class ReceiptTypes:
READ: Final = "m.read" READ: Final = "m.read"
READ_PRIVATE: Final = "org.matrix.msc2285.read.private" READ_PRIVATE: Final = "m.read.private"
UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
FULLY_READ: Final = "m.fully_read" FULLY_READ: Final = "m.fully_read"
@ -268,4 +269,4 @@ class PublicRoomsFilterFields:
""" """
GENERIC_SEARCH_TERM: Final = "generic_search_term" GENERIC_SEARCH_TERM: Final = "generic_search_term"
ROOM_TYPES: Final = "org.matrix.msc3827.room_types" ROOM_TYPES: Final = "room_types"

View file

@ -26,6 +26,7 @@ from twisted.web import http
from synapse.util import json_decoder from synapse.util import json_decoder
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
from synapse.types import JsonDict from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -80,6 +81,12 @@ class Codes(str, Enum):
INVALID_SIGNATURE = "M_INVALID_SIGNATURE" INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED" USER_DEACTIVATED = "M_USER_DEACTIVATED"
# Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
ALREADY_JOINED = "ORG.MATRIX.MSC3848.ALREADY_JOINED"
NOT_JOINED = "ORG.MATRIX.MSC3848.NOT_JOINED"
INSUFFICIENT_POWER = "ORG.MATRIX.MSC3848.INSUFFICIENT_POWER"
# The account has been suspended on the server. # The account has been suspended on the server.
# By opposition to `USER_DEACTIVATED`, this is a reversible measure # By opposition to `USER_DEACTIVATED`, this is a reversible measure
# that can possibly be appealed and reverted. # that can possibly be appealed and reverted.
@ -167,7 +174,7 @@ class SynapseError(CodeMessageException):
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields) return cs_error(self.msg, self.errcode, **self._additional_fields)
@ -213,7 +220,7 @@ class ConsentNotGivenError(SynapseError):
) )
self._consent_uri = consent_uri self._consent_uri = consent_uri
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
@ -307,6 +314,37 @@ class AuthError(SynapseError):
super().__init__(code, msg, errcode, additional_fields) super().__init__(code, msg, errcode, additional_fields)
class UnstableSpecAuthError(AuthError):
"""An error raised when a new error code is being proposed to replace a previous one.
This error will return a "org.matrix.unstable.errcode" property with the new error code,
with the previous error code still being defined in the "errcode" property.
This error will include `org.matrix.msc3848.unstable.errcode` in the C-S error body.
"""
def __init__(
self,
code: int,
msg: str,
errcode: str,
previous_errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None,
):
self.previous_errcode = previous_errcode
super().__init__(code, msg, errcode, additional_fields)
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
fields = {}
if config is not None and config.experimental.msc3848_enabled:
fields["org.matrix.msc3848.unstable.errcode"] = self.errcode
return cs_error(
self.msg,
self.previous_errcode,
**fields,
**self._additional_fields,
)
class InvalidClientCredentialsError(SynapseError): class InvalidClientCredentialsError(SynapseError):
"""An error raised when there was a problem with the authorisation credentials """An error raised when there was a problem with the authorisation credentials
in a client request. in a client request.
@ -338,8 +376,8 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout self._soft_logout = soft_logout
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
d = super().error_dict() d = super().error_dict(config)
d["soft_logout"] = self._soft_logout d["soft_logout"] = self._soft_logout
return d return d
@ -362,7 +400,7 @@ class ResourceLimitError(SynapseError):
self.limit_type = limit_type self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode) super().__init__(code, msg, errcode=errcode)
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error( return cs_error(
self.msg, self.msg,
self.errcode, self.errcode,
@ -397,7 +435,7 @@ class InvalidCaptchaError(SynapseError):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
self.error_url = error_url self.error_url = error_url
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url) return cs_error(self.msg, self.errcode, error_url=self.error_url)
@ -414,7 +452,7 @@ class LimitExceededError(SynapseError):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms self.retry_after_ms = retry_after_ms
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@ -429,7 +467,7 @@ class RoomKeysVersionError(SynapseError):
super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION) super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
self.current_version = current_version self.current_version = current_version
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, current_version=self.current_version) return cs_error(self.msg, self.errcode, current_version=self.current_version)
@ -469,7 +507,7 @@ class IncompatibleRoomVersionError(SynapseError):
self._room_version = room_version self._room_version = room_version
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version) return cs_error(self.msg, self.errcode, room_version=self._room_version)
@ -515,7 +553,7 @@ class UnredactedContentDeletedError(SynapseError):
) )
self.content_keep_ms = content_keep_ms self.content_keep_ms = content_keep_ms
def error_dict(self) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
extra = {} extra = {}
if self.content_keep_ms is not None: if self.content_keep_ms is not None:
extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms} extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms}

View file

@ -17,7 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RateLimitConfig from synapse.config.ratelimiting import RatelimitSettings
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import Requester from synapse.types import Requester
from synapse.util import Clock from synapse.util import Clock
@ -314,8 +314,8 @@ class RequestRatelimiter:
self, self,
store: DataStore, store: DataStore,
clock: Clock, clock: Clock,
rc_message: RateLimitConfig, rc_message: RatelimitSettings,
rc_admin_redaction: Optional[RateLimitConfig], rc_admin_redaction: Optional[RatelimitSettings],
): ):
self.store = store self.store = store
self.clock = clock self.clock = clock

View file

@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
# MSC2716 (importing historical messages) # MSC2716 (importing historical messages)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
# MSC2285 (private read receipts) # MSC2285 (unstable private read receipts)
self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
# MSC3244 (room version capabilities) # MSC3244 (room version capabilities)
@ -88,5 +88,5 @@ class ExperimentalConfig(Config):
# MSC3715: dir param on /relations. # MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
# MSC3827: Filtering of /publicRooms by room type # MSC3848: Introduce errcodes for specific event sending failures
self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False) self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)

View file

@ -21,7 +21,7 @@ from synapse.types import JsonDict
from ._base import Config from ._base import Config
class RateLimitConfig: class RatelimitSettings:
def __init__( def __init__(
self, self,
config: Dict[str, float], config: Dict[str, float],
@ -34,7 +34,7 @@ class RateLimitConfig:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class FederationRateLimitConfig: class FederationRatelimitSettings:
window_size: int = 1000 window_size: int = 1000
sleep_limit: int = 10 sleep_limit: int = 10
sleep_delay: int = 500 sleep_delay: int = 500
@ -50,11 +50,11 @@ class RatelimitConfig(Config):
# Load the new-style messages config if it exists. Otherwise fall back # Load the new-style messages config if it exists. Otherwise fall back
# to the old method. # to the old method.
if "rc_message" in config: if "rc_message" in config:
self.rc_message = RateLimitConfig( self.rc_message = RatelimitSettings(
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
) )
else: else:
self.rc_message = RateLimitConfig( self.rc_message = RatelimitSettings(
{ {
"per_second": config.get("rc_messages_per_second", 0.2), "per_second": config.get("rc_messages_per_second", 0.2),
"burst_count": config.get("rc_message_burst_count", 10.0), "burst_count": config.get("rc_message_burst_count", 10.0),
@ -64,9 +64,9 @@ class RatelimitConfig(Config):
# Load the new-style federation config, if it exists. Otherwise, fall # Load the new-style federation config, if it exists. Otherwise, fall
# back to the old method. # back to the old method.
if "rc_federation" in config: if "rc_federation" in config:
self.rc_federation = FederationRateLimitConfig(**config["rc_federation"]) self.rc_federation = FederationRatelimitSettings(**config["rc_federation"])
else: else:
self.rc_federation = FederationRateLimitConfig( self.rc_federation = FederationRatelimitSettings(
**{ **{
k: v k: v
for k, v in { for k, v in {
@ -80,17 +80,17 @@ class RatelimitConfig(Config):
} }
) )
self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) self.rc_registration = RatelimitSettings(config.get("rc_registration", {}))
self.rc_registration_token_validity = RateLimitConfig( self.rc_registration_token_validity = RatelimitSettings(
config.get("rc_registration_token_validity", {}), config.get("rc_registration_token_validity", {}),
defaults={"per_second": 0.1, "burst_count": 5}, defaults={"per_second": 0.1, "burst_count": 5},
) )
rc_login_config = config.get("rc_login", {}) rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {}))
self.rc_login_failed_attempts = RateLimitConfig( self.rc_login_failed_attempts = RatelimitSettings(
rc_login_config.get("failed_attempts", {}) rc_login_config.get("failed_attempts", {})
) )
@ -101,20 +101,20 @@ class RatelimitConfig(Config):
rc_admin_redaction = config.get("rc_admin_redaction") rc_admin_redaction = config.get("rc_admin_redaction")
self.rc_admin_redaction = None self.rc_admin_redaction = None
if rc_admin_redaction: if rc_admin_redaction:
self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction)
self.rc_joins_local = RateLimitConfig( self.rc_joins_local = RatelimitSettings(
config.get("rc_joins", {}).get("local", {}), config.get("rc_joins", {}).get("local", {}),
defaults={"per_second": 0.1, "burst_count": 10}, defaults={"per_second": 0.1, "burst_count": 10},
) )
self.rc_joins_remote = RateLimitConfig( self.rc_joins_remote = RatelimitSettings(
config.get("rc_joins", {}).get("remote", {}), config.get("rc_joins", {}).get("remote", {}),
defaults={"per_second": 0.01, "burst_count": 10}, defaults={"per_second": 0.01, "burst_count": 10},
) )
# Track the rate of joins to a given room. If there are too many, temporarily # Track the rate of joins to a given room. If there are too many, temporarily
# prevent local joins and remote joins via this server. # prevent local joins and remote joins via this server.
self.rc_joins_per_room = RateLimitConfig( self.rc_joins_per_room = RatelimitSettings(
config.get("rc_joins_per_room", {}), config.get("rc_joins_per_room", {}),
defaults={"per_second": 1, "burst_count": 10}, defaults={"per_second": 1, "burst_count": 10},
) )
@ -124,31 +124,31 @@ class RatelimitConfig(Config):
# * For requests received over federation this is keyed by the origin. # * For requests received over federation this is keyed by the origin.
# #
# Note that this isn't exposed in the configuration as it is obscure. # Note that this isn't exposed in the configuration as it is obscure.
self.rc_key_requests = RateLimitConfig( self.rc_key_requests = RatelimitSettings(
config.get("rc_key_requests", {}), config.get("rc_key_requests", {}),
defaults={"per_second": 20, "burst_count": 100}, defaults={"per_second": 20, "burst_count": 100},
) )
self.rc_3pid_validation = RateLimitConfig( self.rc_3pid_validation = RatelimitSettings(
config.get("rc_3pid_validation") or {}, config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_invites_per_room = RateLimitConfig( self.rc_invites_per_room = RatelimitSettings(
config.get("rc_invites", {}).get("per_room", {}), config.get("rc_invites", {}).get("per_room", {}),
defaults={"per_second": 0.3, "burst_count": 10}, defaults={"per_second": 0.3, "burst_count": 10},
) )
self.rc_invites_per_user = RateLimitConfig( self.rc_invites_per_user = RatelimitSettings(
config.get("rc_invites", {}).get("per_user", {}), config.get("rc_invites", {}).get("per_user", {}),
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_invites_per_issuer = RateLimitConfig( self.rc_invites_per_issuer = RatelimitSettings(
config.get("rc_invites", {}).get("per_issuer", {}), config.get("rc_invites", {}).get("per_issuer", {}),
defaults={"per_second": 0.3, "burst_count": 10}, defaults={"per_second": 0.3, "burst_count": 10},
) )
self.rc_third_party_invite = RateLimitConfig( self.rc_third_party_invite = RatelimitSettings(
config.get("rc_third_party_invite", {}), config.get("rc_third_party_invite", {}),
defaults={ defaults={
"per_second": self.rc_message.per_second, "per_second": self.rc_message.per_second,

View file

@ -30,7 +30,13 @@ from synapse.api.constants import (
JoinRules, JoinRules,
Membership, Membership,
) )
from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.api.errors import (
AuthError,
Codes,
EventSizeError,
SynapseError,
UnstableSpecAuthError,
)
from synapse.api.room_versions import ( from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS, KNOWN_ROOM_VERSIONS,
EventFormatVersions, EventFormatVersions,
@ -291,7 +297,11 @@ def check_state_dependent_auth_rules(
invite_level = get_named_level(auth_dict, "invite", 0) invite_level = get_named_level(auth_dict, "invite", 0)
if user_level < invite_level: if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users") raise UnstableSpecAuthError(
403,
"You don't have permission to invite users",
errcode=Codes.INSUFFICIENT_POWER,
)
else: else:
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
return return
@ -474,7 +484,11 @@ def _is_membership_change_allowed(
return return
if not caller_in_room: # caller isn't joined if not caller_in_room: # caller isn't joined
raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id)) raise UnstableSpecAuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id),
errcode=Codes.NOT_JOINED,
)
if Membership.INVITE == membership: if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently # TODO (erikj): We should probably handle this more intelligently
@ -484,10 +498,18 @@ def _is_membership_change_allowed(
if target_banned: if target_banned:
raise AuthError(403, "%s is banned from the room" % (target_user_id,)) raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % target_user_id) raise UnstableSpecAuthError(
403,
"%s is already in the room." % target_user_id,
errcode=Codes.ALREADY_JOINED,
)
else: else:
if user_level < invite_level: if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users") raise UnstableSpecAuthError(
403,
"You don't have permission to invite users",
errcode=Codes.INSUFFICIENT_POWER,
)
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
# Joins are valid iff caller == target and: # Joins are valid iff caller == target and:
# * They are not banned. # * They are not banned.
@ -549,15 +571,27 @@ def _is_membership_change_allowed(
elif Membership.LEAVE == membership: elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks. # TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level: if target_banned and user_level < ban_level:
raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) raise UnstableSpecAuthError(
403,
"You cannot unban user %s." % (target_user_id,),
errcode=Codes.INSUFFICIENT_POWER,
)
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = get_named_level(auth_events, "kick", 50) kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level: if user_level < kick_level or user_level <= target_level:
raise AuthError(403, "You cannot kick user %s." % target_user_id) raise UnstableSpecAuthError(
403,
"You cannot kick user %s." % target_user_id,
errcode=Codes.INSUFFICIENT_POWER,
)
elif Membership.BAN == membership: elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level: if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban") raise UnstableSpecAuthError(
403,
"You don't have permission to ban",
errcode=Codes.INSUFFICIENT_POWER,
)
elif room_version.msc2403_knocking and Membership.KNOCK == membership: elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK and ( if join_rule != JoinRules.KNOCK and (
not room_version.msc3787_knock_restricted_join_rule not room_version.msc3787_knock_restricted_join_rule
@ -567,7 +601,11 @@ def _is_membership_change_allowed(
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users") raise AuthError(403, "You cannot knock for other users")
elif target_in_room: elif target_in_room:
raise AuthError(403, "You cannot knock on a room you are already in") raise UnstableSpecAuthError(
403,
"You cannot knock on a room you are already in",
errcode=Codes.ALREADY_JOINED,
)
elif caller_invited: elif caller_invited:
raise AuthError(403, "You are already invited to this room") raise AuthError(403, "You are already invited to this room")
elif target_banned: elif target_banned:
@ -638,10 +676,11 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level: if user_level < send_level:
raise AuthError( raise UnstableSpecAuthError(
403, 403,
"You don't have permission to post that to the room. " "You don't have permission to post that to the room. "
+ "user_level (%d) < send_level (%d)" % (user_level, send_level), + "user_level (%d) < send_level (%d)" % (user_level, send_level),
errcode=Codes.INSUFFICIENT_POWER,
) )
# Check state_key # Check state_key
@ -716,9 +755,10 @@ def check_historical(
historical_level = get_named_level(auth_events, "historical", 100) historical_level = get_named_level(auth_events, "historical", 100)
if user_level < historical_level: if user_level < historical_level:
raise AuthError( raise UnstableSpecAuthError(
403, 403,
'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")', 'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")',
errcode=Codes.INSUFFICIENT_POWER,
) )

View file

@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from typing_extensions import Literal
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.events import EventBase from synapse.events import EventBase
@ -33,7 +32,7 @@ class EventContext:
Holds information relevant to persisting an event Holds information relevant to persisting an event
Attributes: Attributes:
rejected: A rejection reason if the event was rejected, else False rejected: A rejection reason if the event was rejected, else None
_state_group: The ID of the state group for this event. Note that state events _state_group: The ID of the state group for this event. Note that state events
are persisted with a state group which includes the new event, so this is are persisted with a state group which includes the new event, so this is
@ -85,7 +84,7 @@ class EventContext:
""" """
_storage: "StorageControllers" _storage: "StorageControllers"
rejected: Union[Literal[False], str] = False rejected: Optional[str] = None
_state_group: Optional[int] = None _state_group: Optional[int] = None
state_group_before_event: Optional[int] = None state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None _state_delta_due_to_event: Optional[StateMap[str]] = None

View file

@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
) )
from synapse.federation.transport.client import SendJoinResponse from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging.opentracing import trace
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -233,6 +234,7 @@ class FederationClient(FederationBase):
destination, content, timeout destination, content, timeout
) )
@trace
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]: ) -> Optional[List[EventBase]]:
@ -403,9 +405,9 @@ class FederationClient(FederationBase):
# Prime the cache # Prime the cache
self._get_pdu_cache[event.event_id] = event self._get_pdu_cache[event.event_id] = event
# FIXME: We should add a `break` here to avoid calling every # Now that we have an event, we can break out of this
# destination after we already found a PDU (will follow-up # loop and stop asking other destinations.
# in a separate PR) break
except SynapseError as e: except SynapseError as e:
logger.info( logger.info(
@ -725,6 +727,12 @@ class FederationClient(FederationBase):
if failover_errcodes is None: if failover_errcodes is None:
failover_errcodes = () failover_errcodes = ()
if not destinations:
# Give a bit of a clearer message if no servers were specified at all.
raise SynapseError(
502, f"Failed to {description} via any server: No servers specified."
)
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -774,7 +782,7 @@ class FederationClient(FederationBase):
"Failed to %s via %s", description, destination, exc_info=True "Failed to %s via %s", description, destination, exc_info=True
) )
raise SynapseError(502, "Failed to %s via any server" % (description,)) raise SynapseError(502, f"Failed to {description} via any server")
async def make_membership_event( async def make_membership_event(
self, self,

View file

@ -469,7 +469,7 @@ class FederationServer(FederationBase):
) )
for pdu in pdus_by_room[room_id]: for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id event_id = pdu.event_id
pdu_results[event_id] = e.error_dict() pdu_results[event_id] = e.error_dict(self.hs.config)
return return
for pdu in pdus_by_room[room_id]: for pdu in pdus_by_room[room_id]:
@ -843,8 +843,25 @@ class FederationServer(FederationBase):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
# Note that get_room_version throws if the room does not exist here.
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
if await self.store.is_partial_state_room(room_id):
# If our server is still only partially joined, we can't give a complete
# response to /send_join, /send_knock or /send_leave.
# This is because we will not be able to provide the server list (for partial
# joins) or the full state (for full joins).
# Return a 404 as we would if we weren't in the room at all.
logger.info(
f"Rejecting /send_{membership_type} to %s because it's a partial state room",
room_id,
)
raise SynapseError(
404,
f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
errcode=Codes.NOT_FOUND,
)
if membership_type == Membership.KNOCK and not room_version.msc2403_knocking: if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
raise SynapseError( raise SynapseError(
403, 403,

View file

@ -565,7 +565,7 @@ class AuthHandler:
except LoginError as e: except LoginError as e:
# this step failed. Merge the error dict into the response # this step failed. Merge the error dict into the response
# so that the client can have another go. # so that the client can have another go.
errordict = e.error_dict() errordict = e.error_dict(self.hs.config)
creds = await self.store.get_completed_ui_auth_stages(session.session_id) creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows: for f in flows:

View file

@ -59,6 +59,7 @@ from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import NOT_SPAM from synapse.module_api import NOT_SPAM
from synapse.replication.http.federation import ( from synapse.replication.http.federation import (
@ -180,6 +181,7 @@ class FederationHandler:
"resume_sync_partial_state_room", self._resume_sync_partial_state_room "resume_sync_partial_state_room", self._resume_sync_partial_state_room
) )
@trace
async def maybe_backfill( async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int self, room_id: str, current_depth: int, limit: int
) -> bool: ) -> bool:
@ -546,9 +548,9 @@ class FederationHandler:
) )
if ret.partial_state: if ret.partial_state:
# TODO(faster_joins): roll this back if we don't manage to start the # Mark the room as having partial state.
# background resync (eg process_remote_join fails) # The background process is responsible for unmarking this flag,
# https://github.com/matrix-org/synapse/issues/12998 # even if the join fails.
await self.store.store_partial_state_room(room_id, ret.servers_in_room) await self.store.store_partial_state_room(room_id, ret.servers_in_room)
try: try:
@ -574,17 +576,21 @@ class FederationHandler:
room_id, room_id,
) )
raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
finally:
if ret.partial_state: # Always kick off the background process that asynchronously fetches
# Kick off the process of asynchronously fetching the state for this # state for the room.
# room. # If the join failed, the background process is responsible for
run_as_background_process( # cleaning up — including unmarking the room as a partial state room.
desc="sync_partial_state_room", if ret.partial_state:
func=self._sync_partial_state_room, # Kick off the process of asynchronously fetching the state for this
initial_destination=origin, # room.
other_destinations=ret.servers_in_room, run_as_background_process(
room_id=room_id, desc="sync_partial_state_room",
) func=self._sync_partial_state_room,
initial_destination=origin,
other_destinations=ret.servers_in_room,
room_id=room_id,
)
# We wait here until this instance has seen the events come down # We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches. # replication (if we're using replication) as the below uses caches.
@ -748,6 +754,23 @@ class FederationHandler:
# (and return a 404 otherwise) # (and return a 404 otherwise)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
if await self.store.is_partial_state_room(room_id):
# If our server is still only partially joined, we can't give a complete
# response to /make_join, so return a 404 as we would if we weren't in the
# room at all.
# The main reason we can't respond properly is that we need to know about
# the auth events for the join event that we would return.
# We also should not bother entertaining the /make_join since we cannot
# handle the /send_join.
logger.info(
"Rejecting /make_join to %s because it's a partial state room", room_id
)
raise SynapseError(
404,
"Unable to handle /make_join right now; this server is not fully joined.",
errcode=Codes.NOT_FOUND,
)
# now check that we are *still* in the room # now check that we are *still* in the room
is_in_room = await self._event_auth_handler.check_host_in_room( is_in_room = await self._event_auth_handler.check_host_in_room(
room_id, self.server_name room_id, self.server_name
@ -1539,15 +1562,16 @@ class FederationHandler:
# Make an infinite iterator of destinations to try. Once we find a working # Make an infinite iterator of destinations to try. Once we find a working
# destination, we'll stick with it until it flakes. # destination, we'll stick with it until it flakes.
destinations: Collection[str]
if initial_destination is not None: if initial_destination is not None:
# Move `initial_destination` to the front of the list. # Move `initial_destination` to the front of the list.
destinations = list(other_destinations) destinations = list(other_destinations)
if initial_destination in destinations: if initial_destination in destinations:
destinations.remove(initial_destination) destinations.remove(initial_destination)
destinations = [initial_destination] + destinations destinations = [initial_destination] + destinations
destination_iter = itertools.cycle(destinations)
else: else:
destination_iter = itertools.cycle(other_destinations) destinations = other_destinations
destination_iter = itertools.cycle(destinations)
# `destination` is the current remote homeserver we're pulling from. # `destination` is the current remote homeserver we're pulling from.
destination = next(destination_iter) destination = next(destination_iter)

View file

@ -59,6 +59,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import nested_logging_context from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import ( from synapse.replication.http.federation import (
@ -278,7 +279,8 @@ class FederationEventHandler:
) )
try: try:
await self._process_received_pdu(origin, pdu, state_ids=None) context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)
except PartialStateConflictError: except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU. # The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time. # Try once more, with full state this time.
@ -286,7 +288,8 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.", "Room %s was un-partial stated while processing the PDU, trying again.",
room_id, room_id,
) )
await self._process_received_pdu(origin, pdu, state_ids=None) context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)
async def on_send_membership_event( async def on_send_membership_event(
self, origin: str, event: EventBase self, origin: str, event: EventBase
@ -316,6 +319,7 @@ class FederationEventHandler:
The event and context of the event after inserting it into the room graph. The event and context of the event after inserting it into the room graph.
Raises: Raises:
RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should computing the state at the event and persisting it. The caller should
@ -376,7 +380,7 @@ class FederationEventHandler:
# need to. # need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context) await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
await self._check_for_soft_fail(event, None, origin=origin) await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context) await self._run_push_actions_and_persist_event(event, context)
return event, context return event, context
@ -534,32 +538,36 @@ class FederationEventHandler:
# #
# This is the same operation as we do when we receive a regular event # This is the same operation as we do when we receive a regular event
# over federation. # over federation.
state_ids = await self._resolve_state_at_missing_prevs(destination, event) context = await self._compute_event_context_with_maybe_missing_prevs(
destination, event
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
) )
if context.partial_state: if context.partial_state:
# this can happen if some or all of the event's prev_events still have # this can happen if some or all of the event's prev_events still have
# partial state - ie, an event has an earlier stream_ordering than one # partial state. We were careful to only pick events from the db without
# or more of its prev_events, so we de-partial-state it before its # partial-state prev events, so that implies that a prev event has
# prev_events. # been persisted (with partial state) since we did the query.
# #
# TODO(faster_joins): we probably need to be more intelligent, and # So, let's just ignore `event` for now; when we re-run the db query
# exclude partial-state prev_events from consideration # we should instead get its partial-state prev event, which we will
# https://github.com/matrix-org/synapse/issues/13001 # de-partial-state, and then come back to event.
logger.warning( logger.warning(
"%s still has partial state: can't de-partial-state it yet", "%s still has prev_events with partial state: can't de-partial-state it yet",
event.event_id, event.event_id,
) )
return return
# since the state at this event has changed, we should now re-evaluate
# whether it should have been rejected. We must already have all of the
# auth events (from last time we went round this path), so there is no
# need to pass the origin.
await self._check_event_auth(None, event, context)
await self._store.update_state_for_partial_state_event(event, context) await self._store.update_state_for_partial_state_event(event, context)
self._state_storage_controller.notify_event_un_partial_stated( self._state_storage_controller.notify_event_un_partial_stated(
event.event_id event.event_id
) )
@trace
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None: ) -> None:
@ -604,6 +612,7 @@ class FederationEventHandler:
backfilled=True, backfilled=True,
) )
@trace
async def _get_missing_events_for_pdu( async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None: ) -> None:
@ -704,6 +713,7 @@ class FederationEventHandler:
logger.info("Got %d prev_events", len(missing_events)) logger.info("Got %d prev_events", len(missing_events))
await self._process_pulled_events(origin, missing_events, backfilled=False) await self._process_pulled_events(origin, missing_events, backfilled=False)
@trace
async def _process_pulled_events( async def _process_pulled_events(
self, origin: str, events: Iterable[EventBase], backfilled: bool self, origin: str, events: Iterable[EventBase], backfilled: bool
) -> None: ) -> None:
@ -742,6 +752,7 @@ class FederationEventHandler:
with nested_logging_context(ev.event_id): with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev, backfilled=backfilled) await self._process_pulled_event(origin, ev, backfilled=backfilled)
@trace
async def _process_pulled_event( async def _process_pulled_event(
self, origin: str, event: EventBase, backfilled: bool self, origin: str, event: EventBase, backfilled: bool
) -> None: ) -> None:
@ -806,29 +817,55 @@ class FederationEventHandler:
return return
try: try:
state_ids = await self._resolve_state_at_missing_prevs(origin, event) try:
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does context = await self._compute_event_context_with_maybe_missing_prevs(
# not return partial state origin, event
# https://github.com/matrix-org/synapse/issues/13002 )
await self._process_received_pdu(
origin,
event,
context,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event
)
await self._process_received_pdu( # We ought to have full state now, barring some unlikely race where we left and
origin, event, state_ids=state_ids, backfilled=backfilled # rejoned the room in the background.
) if context.partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
)
await self._process_received_pdu(
origin,
event,
context,
backfilled=backfilled,
)
except FederationError as e: except FederationError as e:
if e.code == 403: if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id) logger.warning("Pulled event %s failed history check.", event_id)
else: else:
raise raise
async def _resolve_state_at_missing_prevs( async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase self, dest: str, event: EventBase
) -> Optional[StateMap[str]]: ) -> EventContext:
"""Calculate the state at an event with missing prev_events. """Build an EventContext structure for a non-outlier event whose prev_events may
be missing.
This is used when we have pulled a batch of events from a remote server, and This is used when we have pulled a batch of events from a remote server, and may
still don't have all the prev_events. not have all the prev_events.
If we already have all the prev_events for `event`, this method does nothing. To build an EventContext, we need to calculate the state before the event. If we
already have all the prev_events for `event`, we can simply use the state after
the prev_events to calculate the state before `event`.
Otherwise, the missing prevs become new backwards extremities, and we fall back Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`, to asking the remote server for the state after each missing `prev_event`,
@ -849,8 +886,7 @@ class FederationEventHandler:
event: an event to check for missing prevs. event: an event to check for missing prevs.
Returns: Returns:
if we already had all the prev events, `None`. Otherwise, returns The event context.
the event ids of the state at `event`.
Raises: Raises:
FederationError if we fail to get the state from the remote server after any FederationError if we fail to get the state from the remote server after any
@ -864,7 +900,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen missing_prevs = prevs - seen
if not missing_prevs: if not missing_prevs:
return None return await self._state_handler.compute_event_context(event)
logger.info( logger.info(
"Event %s is missing prev_events %s: calculating state for a " "Event %s is missing prev_events %s: calculating state for a "
@ -876,9 +912,15 @@ class FederationEventHandler:
# resolve them to find the correct state at the current event. # resolve them to find the correct state at the current event.
try: try:
# Determine whether we may be about to retrieve partial state
# Events may be un-partial stated right after we compute the partial state
# flag, but that's okay, as long as the flag errs on the conservative side.
partial_state_flags = await self._store.get_partial_state_events(seen)
partial_state = any(partial_state_flags.values())
# Get the state of the events we know about # Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids( ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen room_id, seen, await_full_state=False
) )
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
@ -924,7 +966,9 @@ class FederationEventHandler:
"We can't get valid state history.", "We can't get valid state history.",
affected=event_id, affected=event_id,
) )
return state_map return await self._state_handler.compute_event_context(
event, state_ids_before_event=state_map, partial_state=partial_state
)
async def _get_state_ids_after_missing_prev_event( async def _get_state_ids_after_missing_prev_event(
self, self,
@ -1093,7 +1137,7 @@ class FederationEventHandler:
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
state_ids: Optional[StateMap[str]], context: EventContext,
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
"""Called when we have a new non-outlier event. """Called when we have a new non-outlier event.
@ -1115,24 +1159,18 @@ class FederationEventHandler:
event: event to be persisted event: event to be persisted
state_ids: Normally None, but if we are handling a gap in the graph context: The `EventContext` to persist the event with.
(ie, we are missing one or more prev_events), the resolved state at the
event. Must not be partial state.
backfilled: True if this is part of a historical batch of events (inhibits backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.) notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry computing the state at the event and persisting it. The caller should
exactly once in this case. Will never be raised if `state_ids` is provided. recompute `context` and retry exactly once when this happens.
""" """
logger.debug("Processing event: %s", event) logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier assert not event.internal_metadata.outlier
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
)
try: try:
await self._check_event_auth(origin, event, context) await self._check_event_auth(origin, event, context)
except AuthError as e: except AuthError as e:
@ -1144,7 +1182,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event # For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we # passes auth based on the current state. If it doesn't then we
# "soft-fail" the event. # "soft-fail" the event.
await self._check_for_soft_fail(event, state_ids, origin=origin) await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled) await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1556,13 +1594,15 @@ class FederationEventHandler:
) )
async def _check_event_auth( async def _check_event_auth(
self, origin: str, event: EventBase, context: EventContext self, origin: Optional[str], event: EventBase, context: EventContext
) -> None: ) -> None:
""" """
Checks whether an event should be rejected (for failing auth checks). Checks whether an event should be rejected (for failing auth checks).
Args: Args:
origin: The host the event originates from. origin: The host the event originates from. This is used to fetch
any missing auth events. It can be set to None, but only if we are
sure that we already have all the auth events.
event: The event itself. event: The event itself.
context: context:
The event context. The event context.
@ -1705,7 +1745,7 @@ class FederationEventHandler:
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, self,
event: EventBase, event: EventBase,
state_ids: Optional[StateMap[str]], context: EventContext,
origin: str, origin: str,
) -> None: ) -> None:
"""Checks if we should soft fail the event; if so, marks the event as """Checks if we should soft fail the event; if so, marks the event as
@ -1716,7 +1756,7 @@ class FederationEventHandler:
Args: Args:
event event
state_ids: The state at the event if we don't have all the event's prev events context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from. origin: The host the event originates from.
""" """
if await self._store.is_partial_state_room(event.room_id): if await self._store.is_partial_state_room(event.room_id):
@ -1742,11 +1782,15 @@ class FederationEventHandler:
auth_types = auth_types_for_event(room_version_obj, event) auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state". # Calculate the "current state".
if state_ids is not None: seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
# If we're explicitly given the state then we won't have all the has_missing_prevs = bool(prev_event_ids - seen_event_ids)
# prev events, and so we have a gap in the graph. In this case if has_missing_prevs:
# we want to be a little careful as we might have been down for # We don't have all the prev_events of this event, which means we have a
# a while and have an incorrect view of the current state, # gap in the graph, and the new event is going to become a new backwards
# extremity.
#
# In this case we want to be a little careful as we might have been
# down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to # however we still want to do checks as gaps are easy to
# maliciously manufacture. # maliciously manufacture.
# #
@ -1759,6 +1803,7 @@ class FederationEventHandler:
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids) state_sets.append(state_ids)
current_state_ids = ( current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store( await self._state_resolution_handler.resolve_events_with_store(
@ -1808,7 +1853,7 @@ class FederationEventHandler:
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
async def _load_or_fetch_auth_events_for_event( async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase self, destination: Optional[str], event: EventBase
) -> Collection[EventBase]: ) -> Collection[EventBase]:
"""Fetch this event's auth_events, from database or remote """Fetch this event's auth_events, from database or remote
@ -1824,12 +1869,19 @@ class FederationEventHandler:
Args: Args:
destination: where to send the /event_auth request. Typically the server destination: where to send the /event_auth request. Typically the server
that sent us `event` in the first place. that sent us `event` in the first place.
If this is None, no attempt is made to load any missing auth events:
rather, an AssertionError is raised if there are any missing events.
event: the event whose auth_events we want event: the event whose auth_events we want
Returns: Returns:
all of the events listed in `event.auth_events_ids`, after deduplication all of the events listed in `event.auth_events_ids`, after deduplication
Raises: Raises:
AssertionError if some auth events were missing and no `destination` was
supplied.
AuthError if we were unable to fetch the auth_events for any reason. AuthError if we were unable to fetch the auth_events for any reason.
""" """
event_auth_event_ids = set(event.auth_event_ids()) event_auth_event_ids = set(event.auth_event_ids())
@ -1841,6 +1893,13 @@ class FederationEventHandler:
) )
if not missing_auth_event_ids: if not missing_auth_event_ids:
return event_auth_events.values() return event_auth_events.values()
if destination is None:
# this shouldn't happen: destination must be set unless we know we have already
# persisted the auth events.
raise AssertionError(
"_load_or_fetch_auth_events_for_event() called with no destination for "
"an event with missing auth_events"
)
logger.info( logger.info(
"Event %s refers to unknown auth events %s: fetching auth chain", "Event %s refers to unknown auth events %s: fetching auth chain",

View file

@ -143,8 +143,8 @@ class InitialSyncHandler:
joined_rooms, joined_rooms,
to_key=int(now_token.receipt_key), to_key=int(now_token.receipt_key),
) )
if self.hs.config.experimental.msc2285_enabled:
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
@ -456,11 +456,8 @@ class InitialSyncHandler:
) )
if not receipts: if not receipts:
return [] return []
if self.hs.config.experimental.msc2285_enabled:
receipts = ReceiptEventSource.filter_out_private_receipts( return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
receipts, user_id
)
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable( presence, receipts, (messages, token) = await make_deferred_yieldable(
gather_results( gather_results(

View file

@ -41,6 +41,7 @@ from synapse.api.errors import (
NotFoundError, NotFoundError,
ShadowBanError, ShadowBanError,
SynapseError, SynapseError,
UnstableSpecAuthError,
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -51,6 +52,7 @@ from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -149,7 +151,11 @@ class MessageHandler:
"Attempted to retrieve data from a room for a user that has never been in it. " "Attempted to retrieve data from a room for a user that has never been in it. "
"This should not have happened." "This should not have happened."
) )
raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN) raise UnstableSpecAuthError(
403,
"User not in room",
errcode=Codes.NOT_JOINED,
)
return data return data
@ -319,8 +325,10 @@ class MessageHandler:
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=True
) )
if membership != Membership.JOIN: if membership != Membership.JOIN:
raise NotImplementedError( raise SynapseError(
"Getting joined members after leaving is not implemented" code=403,
errcode=Codes.FORBIDDEN,
msg="Getting joined members while not being a current member of the room is forbidden.",
) )
users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
@ -334,7 +342,11 @@ class MessageHandler:
break break
else: else:
# Loop fell through, AS has no interested users in room # Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room") raise UnstableSpecAuthError(
403,
"Appservice not in room",
errcode=Codes.NOT_JOINED,
)
return { return {
user_id: { user_id: {
@ -1137,6 +1149,10 @@ class EventCreationHandler:
context = await self.state.compute_event_context( context = await self.state.compute_event_context(
event, event,
state_ids_before_event=state_map_for_event, state_ids_before_event=state_map_for_event,
# TODO(faster_joins): check how MSC2716 works and whether we can have
# partial state here
# https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
) )
else: else:
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)
@ -1366,9 +1382,10 @@ class EventCreationHandler:
# and `state_groups` because they have `prev_events` that aren't persisted yet # and `state_groups` because they have `prev_events` that aren't persisted yet
# (historical messages persisted in reverse-chronological order). # (historical messages persisted in reverse-chronological order).
if not event.internal_metadata.is_historical() and not event.content.get(EventContentFields.MSC2716_HISTORICAL): if not event.internal_metadata.is_historical() and not event.content.get(EventContentFields.MSC2716_HISTORICAL):
await self._bulk_push_rule_evaluator.action_for_event_by_user( with opentracing.start_active_span("calculate_push_actions"):
event, context await self._bulk_push_rule_evaluator.action_for_event_by_user(
) event, context
)
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
@ -1456,9 +1473,10 @@ class EventCreationHandler:
state = await state_entry.get_state( state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all() self._storage_controllers.state, StateFilter.all()
) )
joined_hosts = await self.store.get_joined_hosts( with opentracing.start_active_span("get_joined_hosts"):
event.room_id, state, state_entry joined_hosts = await self.store.get_joined_hosts(
) event.room_id, state, state_entry
)
# Note that the expiry times must be larger than the expiry time in # Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates. # _external_cache_joined_hosts_updates.

View file

@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse from synapse.handlers.room import ShutdownRoomResponse
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
@ -416,6 +417,7 @@ class PaginationHandler:
await self._storage_controllers.purge_events.purge_room(room_id) await self._storage_controllers.purge_events.purge_room(room_id)
@trace
async def get_messages( async def get_messages(
self, self,
requester: Requester, requester: Requester,

View file

@ -164,7 +164,10 @@ class ReceiptsHandler:
if not is_new: if not is_new:
return return
if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: if self.federation_sender and receipt_type not in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
):
await self.federation_sender.send_read_receipt(receipt) await self.federation_sender.send_read_receipt(receipt)
@ -204,24 +207,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
for event_id, orig_event_content in room.get("content", {}).items(): for event_id, orig_event_content in room.get("content", {}).items():
event_content = orig_event_content event_content = orig_event_content
# If there are private read receipts, additional logic is necessary. # If there are private read receipts, additional logic is necessary.
if ReceiptTypes.READ_PRIVATE in event_content: if (
ReceiptTypes.READ_PRIVATE in event_content
or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content
):
# Make a copy without private read receipts to avoid leaking # Make a copy without private read receipts to avoid leaking
# other user's private read receipts.. # other user's private read receipts..
event_content = { event_content = {
receipt_type: receipt_value receipt_type: receipt_value
for receipt_type, receipt_value in event_content.items() for receipt_type, receipt_value in event_content.items()
if receipt_type != ReceiptTypes.READ_PRIVATE if receipt_type
not in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
)
} }
# Copy the current user's private read receipt from the # Copy the current user's private read receipt from the
# original content, if it exists. # original content, if it exists.
user_private_read_receipt = orig_event_content[ user_private_read_receipt = orig_event_content.get(
ReceiptTypes.READ_PRIVATE ReceiptTypes.READ_PRIVATE, {}
].get(user_id, None) ).get(user_id, None)
if user_private_read_receipt: if user_private_read_receipt:
event_content[ReceiptTypes.READ_PRIVATE] = { event_content[ReceiptTypes.READ_PRIVATE] = {
user_id: user_private_read_receipt user_id: user_private_read_receipt
} }
user_unstable_private_read_receipt = orig_event_content.get(
ReceiptTypes.UNSTABLE_READ_PRIVATE, {}
).get(user_id, None)
if user_unstable_private_read_receipt:
event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = {
user_id: user_unstable_private_read_receipt
}
# Include the event if there is at least one non-private read # Include the event if there is at least one non-private read
# receipt or the current user has a private read receipt. # receipt or the current user has a private read receipt.
@ -257,10 +274,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
room_ids, from_key=from_key, to_key=to_key room_ids, from_key=from_key, to_key=to_key
) )
if self.config.experimental.msc2285_enabled: events = ReceiptEventSource.filter_out_private_receipts(
events = ReceiptEventSource.filter_out_private_receipts( events, user.to_string()
events, user.to_string() )
)
return events, to_key return events, to_key

View file

@ -19,6 +19,7 @@ import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import _RelatedEvent from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -73,7 +74,6 @@ class RelationsHandler:
room_id: str, room_id: str,
relation_type: Optional[str] = None, relation_type: Optional[str] = None,
event_type: Optional[str] = None, event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5, limit: int = 5,
direction: str = "b", direction: str = "b",
from_token: Optional[StreamToken] = None, from_token: Optional[StreamToken] = None,
@ -89,7 +89,6 @@ class RelationsHandler:
room_id: The room the event belongs to. room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given. relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given. event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events. limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`). oldest first (`"f"`).
@ -122,7 +121,6 @@ class RelationsHandler:
room_id=room_id, room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
aggregation_key=aggregation_key,
limit=limit, limit=limit,
direction=direction, direction=direction,
from_token=from_token, from_token=from_token,
@ -364,6 +362,7 @@ class RelationsHandler:
return results return results
@trace
async def get_bundled_aggregations( async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]: ) -> Dict[str, BundledAggregations]:

View file

@ -182,7 +182,7 @@ class RoomListHandler:
== HistoryVisibility.WORLD_READABLE, == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join", "guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"], "join_rule": room["join_rules"],
"org.matrix.msc3827.room_type": room["room_type"], "room_type": room["room_type"],
} }
# Filter out Nones rather omit the field altogether # Filter out Nones rather omit the field altogether

View file

@ -32,6 +32,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM from synapse.module_api import NOT_SPAM
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
@ -428,14 +429,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._join_rate_per_room_limiter.ratelimit( await self._join_rate_per_room_limiter.ratelimit(
requester, key=room_id, update=False requester, key=room_id, update=False
) )
with opentracing.start_active_span("handle_new_client_event"):
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, requester,
event, event,
context, context,
extra_users=[target], extra_users=[target],
ratelimit=ratelimit, ratelimit=ratelimit,
) )
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
@ -564,25 +565,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# by application services), and then by room ID. # by application services), and then by room ID.
async with self.member_as_limiter.queue(as_id): async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key): async with self.member_linearizer.queue(key):
result = await self.update_membership_locked( with opentracing.start_active_span("update_membership_locked"):
requester, result = await self.update_membership_locked(
target, requester,
room_id, target,
action, room_id,
txn_id=txn_id, action,
remote_room_hosts=remote_room_hosts, txn_id=txn_id,
third_party_signed=third_party_signed, remote_room_hosts=remote_room_hosts,
ratelimit=ratelimit, third_party_signed=third_party_signed,
content=content, ratelimit=ratelimit,
new_room=new_room, content=content,
require_consent=require_consent, new_room=new_room,
outlier=outlier, require_consent=require_consent,
historical=historical, outlier=outlier,
allow_no_prev_events=allow_no_prev_events, historical=historical,
prev_event_ids=prev_event_ids, allow_no_prev_events=allow_no_prev_events,
state_event_ids=state_event_ids, prev_event_ids=prev_event_ids,
depth=depth, state_event_ids=state_event_ids,
) depth=depth,
)
return result return result
@ -649,6 +651,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Returns: Returns:
A tuple of the new event ID and stream ID. A tuple of the new event ID and stream ID.
""" """
content_specified = bool(content) content_specified = bool(content)
if content is None: if content is None:
content = {} content = {}
@ -1679,7 +1682,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
] ]
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers") raise SynapseError(
404,
"Can't join remote room because no servers "
"that are in the room have been provided.",
)
check_complexity = self.hs.config.server.limit_remote_rooms.enabled check_complexity = self.hs.config.server.limit_remote_rooms.enabled
if ( if (

View file

@ -28,11 +28,11 @@ from synapse.api.constants import (
RoomTypes, RoomTypes,
) )
from synapse.api.errors import ( from synapse.api.errors import (
AuthError,
Codes, Codes,
NotFoundError, NotFoundError,
StoreError, StoreError,
SynapseError, SynapseError,
UnstableSpecAuthError,
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
@ -175,10 +175,11 @@ class RoomSummaryHandler:
# First of all, check that the room is accessible. # First of all, check that the room is accessible.
if not await self._is_local_room_accessible(requested_room_id, requester): if not await self._is_local_room_accessible(requested_room_id, requester):
raise AuthError( raise UnstableSpecAuthError(
403, 403,
"User %s not in room %s, and room previews are disabled" "User %s not in room %s, and room previews are disabled"
% (requester, requested_room_id), % (requester, requested_room_id),
errcode=Codes.NOT_JOINED,
) )
# If this is continuing a previous session, pull the persisted data. # If this is continuing a previous session, pull the persisted data.
@ -452,7 +453,6 @@ class RoomSummaryHandler:
"type": e.type, "type": e.type,
"state_key": e.state_key, "state_key": e.state_key,
"content": e.content, "content": e.content,
"room_id": e.room_id,
"sender": e.sender, "sender": e.sender,
"origin_server_ts": e.origin_server_ts, "origin_server_ts": e.origin_server_ts,
} }

View file

@ -1535,15 +1535,13 @@ class SyncHandler:
ignored_users = await self.store.ignored_users(user_id) ignored_users = await self.store.ignored_users(user_id)
if since_token: if since_token:
room_changes = await self._get_rooms_changed( room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users, self.rooms_to_exclude sync_result_builder, ignored_users
) )
tags_by_room = await self.store.get_updated_tags( tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key user_id, since_token.account_data_key
) )
else: else:
room_changes = await self._get_all_rooms( room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
sync_result_builder, ignored_users, self.rooms_to_exclude
)
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)}) log_kv({"rooms_changed": len(room_changes.room_entries)})
@ -1622,13 +1620,14 @@ class SyncHandler:
self, self,
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str], ignored_users: FrozenSet[str],
excluded_rooms: List[str],
) -> _RoomChanges: ) -> _RoomChanges:
"""Determine the changes in rooms to report to the user. """Determine the changes in rooms to report to the user.
This function is a first pass at generating the rooms part of the sync response. This function is a first pass at generating the rooms part of the sync response.
It determines which rooms have changed during the sync period, and categorises It determines which rooms have changed during the sync period, and categorises
them into four buckets: "knock", "invite", "join" and "leave". them into four buckets: "knock", "invite", "join" and "leave". It also excludes
from that list any room that appears in the list of rooms to exclude from sync
results in the server configuration.
1. Finds all membership changes for the user in the sync period (from 1. Finds all membership changes for the user in the sync period (from
`since_token` up to `now_token`). `since_token` up to `now_token`).
@ -1654,7 +1653,7 @@ class SyncHandler:
# _have_rooms_changed. We could keep the results in memory to avoid a # _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code. # second query, at the cost of more complicated source code.
membership_change_events = await self.store.get_membership_changes_for_user( membership_change_events = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key, excluded_rooms user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
) )
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@ -1861,7 +1860,6 @@ class SyncHandler:
self, self,
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str], ignored_users: FrozenSet[str],
ignored_rooms: List[str],
) -> _RoomChanges: ) -> _RoomChanges:
"""Returns entries for all rooms for the user. """Returns entries for all rooms for the user.
@ -1883,7 +1881,7 @@ class SyncHandler:
room_list = await self.store.get_rooms_for_local_user_where_membership_is( room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, user_id=user_id,
membership_list=Membership.LIST, membership_list=Membership.LIST,
excluded_rooms=ignored_rooms, excluded_rooms=self.rooms_to_exclude,
) )
room_entries = [] room_entries = []
@ -2149,7 +2147,9 @@ class SyncHandler:
raise Exception("Unrecognized rtype: %r", room_builder.rtype) raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at( async def get_rooms_for_user_at(
self, user_id: str, room_key: RoomStreamToken self,
user_id: str,
room_key: RoomStreamToken,
) -> FrozenSet[str]: ) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering. """Get set of joined rooms for a user at the given stream ordering.
@ -2175,7 +2175,12 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream # If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room # ordering, we need to go and work out if the user was in the room
# before. # before.
# We also need to check whether the room should be excluded from sync
# responses as per the homeserver config.
for joined_room in joined_rooms: for joined_room in joined_rooms:
if joined_room.room_id in self.rooms_to_exclude:
continue
if not joined_room.event_pos.persisted_after(room_key): if not joined_room.event_pos.persisted_after(room_key):
joined_room_ids.add(joined_room.room_id) joined_room_ids.add(joined_room.room_id)
continue continue

View file

@ -489,8 +489,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
handler = self.get_typing_handler() handler = self.get_typing_handler()
events = [] events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key: # Work on a copy of things here as these may change in the handler while
# waiting for the AS `is_interested_in_room` call to complete.
# Shallow copy is safe as no nested data is present.
latest_room_serial = handler._latest_room_serial
room_serials = handler._room_serials.copy()
for room_id, serial in room_serials.items():
if serial <= from_key:
continue continue
if not await service.is_interested_in_room(room_id, self._main_store): if not await service.is_interested_in_room(room_id, self._main_store):
@ -498,7 +505,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
return events, handler._latest_room_serial return events, latest_room_serial
async def get_new_events( async def get_new_events(
self, self,

View file

@ -58,6 +58,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
@ -155,15 +156,16 @@ def is_method_cancellable(method: Callable[..., Any]) -> bool:
return getattr(method, "cancellable", False) return getattr(method, "cancellable", False)
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: def return_json_error(
f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
) -> None:
"""Sends a JSON error response to clients.""" """Sends a JSON error response to clients."""
if f.check(SynapseError): if f.check(SynapseError):
# mypy doesn't understand that f.check asserts the type. # mypy doesn't understand that f.check asserts the type.
exc: SynapseError = f.value # type: ignore exc: SynapseError = f.value # type: ignore
error_code = exc.code error_code = exc.code
error_dict = exc.error_dict() error_dict = exc.error_dict(config)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
elif f.check(CancelledError): elif f.check(CancelledError):
error_code = HTTP_STATUS_REQUEST_CANCELLED error_code = HTTP_STATUS_REQUEST_CANCELLED
@ -450,7 +452,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest, request: SynapseRequest,
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response""" """Implements _AsyncResource._send_error_response"""
return_json_error(f, request) return_json_error(f, request, None)
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -575,6 +577,14 @@ class JsonResource(DirectServeJsonResource):
return callback_return return callback_return
def _send_error_response(
self,
f: failure.Failure,
request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request, self.hs.config)
class DirectServeHtmlResource(_AsyncResource): class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests, """A resource that will call `self._async_on_<METHOD>` on new requests,

View file

@ -901,6 +901,11 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
def tag_args(func: Callable[P, R]) -> Callable[P, R]: def tag_args(func: Callable[P, R]) -> Callable[P, R]:
""" """
Tags all of the args to the active span. Tags all of the args to the active span.
Args:
func: `func` is assumed to be a method taking a `self` parameter, or a
`classmethod` taking a `cls` parameter. In either case, a tag is not
created for this parameter.
""" """
if not opentracing: if not opentracing:
@ -909,8 +914,14 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func) @wraps(func)
def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
argspec = inspect.getfullargspec(func) argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]): # We use `[1:]` to skip the `self` object reference and `start=1` to
set_tag("ARG_" + arg, str(args[i])) # type: ignore[index] # make the index line up with `argspec.args`.
#
# FIXME: We could update this handle any type of function by ignoring the
# first argument only if it's named `self` or `cls`. This isn't fool-proof
# but handles the idiomatic cases.
for i, arg in enumerate(args[1:], start=1): # type: ignore[index]
set_tag("ARG_" + argspec.args[i], str(arg))
set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index] set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index]
set_tag("kwargs", str(kwargs)) set_tag("kwargs", str(kwargs))
return func(*args, **kwargs) return func(*args, **kwargs)

View file

@ -929,10 +929,12 @@ class ModuleApi:
room_id: str, room_id: str,
new_membership: str, new_membership: str,
content: Optional[JsonDict] = None, content: Optional[JsonDict] = None,
remote_room_hosts: Optional[List[str]] = None,
) -> EventBase: ) -> EventBase:
"""Updates the membership of a user to the given value. """Updates the membership of a user to the given value.
Added in Synapse v1.46.0. Added in Synapse v1.46.0.
Changed in Synapse v1.65.0: Added the 'remote_room_hosts' parameter.
Args: Args:
sender: The user performing the membership change. Must be a user local to sender: The user performing the membership change. Must be a user local to
@ -946,6 +948,7 @@ class ModuleApi:
https://spec.matrix.org/unstable/client-server-api/#mroommember for the https://spec.matrix.org/unstable/client-server-api/#mroommember for the
list of allowed values. list of allowed values.
content: Additional values to include in the resulting event's content. content: Additional values to include in the resulting event's content.
remote_room_hosts: Remote servers to use for remote joins/knocks/etc.
Returns: Returns:
The newly created membership event. The newly created membership event.
@ -1005,15 +1008,12 @@ class ModuleApi:
room_id=room_id, room_id=room_id,
action=new_membership, action=new_membership,
content=content, content=content,
remote_room_hosts=remote_room_hosts,
) )
# Try to retrieve the resulting event. # Try to retrieve the resulting event.
event = await self._hs.get_datastores().main.get_event(event_id) event = await self._hs.get_datastores().main.get_event(event_id)
# update_membership is supposed to always return after the event has been
# successfully persisted.
assert event is not None
return event return event
async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase: async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
@ -1452,6 +1452,81 @@ class ModuleApi:
start_timestamp, end_timestamp start_timestamp, end_timestamp
) )
async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]:
"""
Get the room ID associated with a room alias.
Added in Synapse v1.65.0.
Args:
room_alias: The alias to look up.
Returns:
A tuple of:
The room ID (str).
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias is invalid or could not be found.
"""
alias = RoomAlias.from_string(room_alias)
(room_id, hosts) = await self._hs.get_room_member_handler().lookup_room_alias(
alias
)
return room_id.to_string(), hosts
async def create_room(
self,
user_id: str,
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[str, Optional[str]]:
"""Creates a new room.
Added in Synapse v1.65.0.
Args:
user_id:
The user who requested the room creation.
config : A dict of configuration options. See "Request body" of:
https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
ratelimit: set to False to disable the rate limiter for this specific operation.
creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
values to go in the body of the 'join' event (typically
`avatar_url` and/or `displayname`.
Returns:
A tuple containing: 1) the room ID (str), 2) if an alias was requested,
the room alias (str), otherwise None if no alias was requested.
Raises:
ResourceLimitError if server is blocked to some resource being
exceeded.
RuntimeError if the user_id does not refer to a local user.
SynapseError if the user_id is invalid, room ID couldn't be stored, or
something went horribly wrong.
"""
if not self.is_mine(user_id):
raise RuntimeError(
"Tried to create a room as a user that isn't local to this homeserver",
)
requester = create_requester(user_id)
room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
requester=requester,
config=config,
ratelimit=ratelimit,
creator_join_profile=creator_join_profile,
)
return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
class PublicRoomListManager: class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room """Contains methods for adding to, removing from and querying whether a room

View file

@ -416,7 +416,10 @@ class FederationSenderHandler:
if not self._is_mine_id(receipt.user_id): if not self._is_mine_id(receipt.user_id):
continue continue
# Private read receipts never get sent over federation. # Private read receipts never get sent over federation.
if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: if receipt.receipt_type in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
):
continue continue
receipt_info = ReadReceipt( receipt_info = ReadReceipt(
receipt.room_id, receipt.room_id,

View file

@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet):
) )
receipts_by_room = await self.store.get_receipts_for_user_with_orderings( receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] user_id,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
notif_event_ids = [pa.event_id for pa in push_actions] notif_event_ids = [pa.event_id for pa in push_actions]

View file

@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} self._known_receipt_types = {
ReceiptTypes.READ,
ReceiptTypes.FULLY_READ,
ReceiptTypes.READ_PRIVATE,
}
if hs.config.experimental.msc2285_enabled: if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str

View file

@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ} self._known_receipt_types = {
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.FULLY_READ,
}
if hs.config.experimental.msc2285_enabled: if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.update( self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str

View file

@ -33,7 +33,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
@ -325,7 +325,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter( self.ratelimiter = FederationRateLimiter(
hs.get_clock(), hs.get_clock(),
FederationRateLimitConfig( FederationRatelimitSettings(
# Time window of 2s # Time window of 2s
window_size=2000, window_size=2000,
# Artificially delay requests if rate > sleep_limit/window_size # Artificially delay requests if rate > sleep_limit/window_size

View file

@ -94,9 +94,10 @@ class VersionsRestServlet(RestServlet):
# Supports the busy presence state described in MSC3026. # Supports the busy presence state described in MSC3026.
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285 # Supports receiving private read receipts as per MSC2285
"org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec
"org.matrix.msc2285": self.config.experimental.msc2285_enabled, "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
# Supports filtering of /publicRooms by room type MSC3827 # Supports filtering of /publicRooms by room type as per MSC3827
"org.matrix.msc3827": self.config.experimental.msc3827_enabled, "org.matrix.msc3827.stable": True,
# Adds support for importing historical messages as per MSC2716 # Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled, "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030

View file

@ -255,7 +255,7 @@ class StateHandler:
self, self,
event: EventBase, event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None, state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False, partial_state: Optional[bool] = None,
) -> EventContext: ) -> EventContext:
"""Build an EventContext structure for a non-outlier event. """Build an EventContext structure for a non-outlier event.
@ -270,10 +270,18 @@ class StateHandler:
it can't be calculated from existing events. This is normally it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling. don't have the prev events, e.g. when backfilling.
partial_state: True if `state_ids_before_event` is partial and omits partial_state:
non-critical membership events `True` if `state_ids_before_event` is partial and omits non-critical
membership events.
`False` if `state_ids_before_event` is the full state.
`None` when `state_ids_before_event` is not provided. In this case, the
flag will be calculated based on `event`'s prev events.
Returns: Returns:
The event context. The event context.
Raises:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
""" """
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()
@ -298,12 +306,14 @@ class StateHandler:
) )
) )
# the partial_state flag must be provided
assert partial_state is not None
else: else:
# otherwise, we'll need to resolve the state across the prev_events. # otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case: # partial_state should not be set explicitly in this case:
# we work it out dynamically # we work it out dynamically
assert not partial_state assert partial_state is None
# if any of the prev-events have partial state, so do we. # if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use # (This is slightly racy - the prev-events might get fixed up before we use
@ -313,13 +323,13 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events( incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids prev_event_ids
) )
if any(incomplete_prev_events.values()): partial_state = any(incomplete_prev_events.values())
if partial_state:
logger.debug( logger.debug(
"New/incoming event %s refers to prev_events %s with partial state", "New/incoming event %s refers to prev_events %s with partial state",
event.event_id, event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v], [k for (k, v) in incomplete_prev_events.items() if v],
) )
partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for # we've already taken into account partial state, so no need to wait for
@ -426,6 +436,10 @@ class StateHandler:
Returns: Returns:
The resolved state The resolved state
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)

View file

@ -434,7 +434,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: str, event_id: str,
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore, state_res_store: StateResolutionStore,
auth_diff: Set[str], full_conflicted_set: Set[str],
) -> None: ) -> None:
"""Helper function for _reverse_topological_power_sort that add the event """Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph and its auth chain (that is in the auth diff) to the graph
@ -445,7 +445,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: Event to add to the graph event_id: Event to add to the graph
event_map event_map
state_res_store state_res_store
auth_diff: Set of event IDs that are in the auth difference. full_conflicted_set: Set of event IDs that are in the full conflicted set.
""" """
state = [event_id] state = [event_id]
@ -455,7 +455,7 @@ async def _add_event_and_auth_chain_to_graph(
event = await _get_event(room_id, eid, event_map, state_res_store) event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids(): for aid in event.auth_event_ids():
if aid in auth_diff: if aid in full_conflicted_set:
if aid not in graph: if aid not in graph:
state.append(aid) state.append(aid)
@ -468,7 +468,7 @@ async def _reverse_topological_power_sort(
event_ids: Iterable[str], event_ids: Iterable[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore, state_res_store: StateResolutionStore,
auth_diff: Set[str], full_conflicted_set: Set[str],
) -> List[str]: ) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering, """Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts and then by power level and origin_server_ts
@ -479,7 +479,7 @@ async def _reverse_topological_power_sort(
event_ids: The events to sort event_ids: The events to sort
event_map event_map
state_res_store state_res_store
auth_diff: Set of event IDs that are in the auth difference. full_conflicted_set: Set of event IDs that are in the full conflicted set.
Returns: Returns:
The sorted list The sorted list
@ -488,7 +488,7 @@ async def _reverse_topological_power_sort(
graph: Dict[str, Set[str]] = {} graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1): for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph( await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
) )
# We await occasionally when we're working with large data sets to # We await occasionally when we're working with large data sets to

View file

@ -29,6 +29,7 @@ from typing import (
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.opentracing import trace
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import ( from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker, PartialCurrentStateTracker,
@ -82,13 +83,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids( async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str] self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
Args: Args:
_room_id: id of the room for these events _room_id: id of the room for these events
event_ids: ids of the events event_ids: ids of the events
await_full_state: if `True`, will block if we do not yet have complete
state at these events.
Returns: Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
@ -100,7 +103,9 @@ class StateStorageController:
if not event_ids: if not event_ids:
return {} return {}
event_to_groups = await self.get_state_group_for_events(event_ids) event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups) group_to_state = await self.stores.state._get_state_for_groups(groups)
@ -175,6 +180,7 @@ class StateStorageController:
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@trace
async def get_state_for_events( async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
@ -221,6 +227,7 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@trace
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, self,
event_ids: Collection[str], event_ids: Collection[str],
@ -283,6 +290,7 @@ class StateStorageController:
) )
return state_map[event_id] return state_map[event_id]
@trace
async def get_state_ids_for_event( async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]: ) -> StateMap[str]:
@ -323,6 +331,7 @@ class StateStorageController:
groups, state_filter or StateFilter.all() groups, state_filter or StateFilter.all()
) )
@trace
async def get_state_group_for_events( async def get_state_group_for_events(
self, self,
event_ids: Collection[str], event_ids: Collection[str],
@ -334,6 +343,10 @@ class StateStorageController:
event_ids: events to get state groups for event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete await_full_state: if true, will block if we do not yet have complete
state at these events. state at these events.
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
""" """
if await_full_state: if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids) await self._partial_state_events_tracker.await_full_state(event_ids)

View file

@ -12,6 +12,67 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Responsible for storing and fetching push actions / notifications.
There are two main uses for push actions:
1. Sending out push to a user's device; and
2. Tracking per-room per-user notification counts (used in sync requests).
For the former we simply use the `event_push_actions` table, which contains all
the calculated actions for a given user (which were calculated by the
`BulkPushRuleEvaluator`).
For the latter we could simply count the number of rows in `event_push_actions`
table for a given room/user, but in practice this is *very* heavyweight when
there were a large number of notifications (due to e.g. the user never reading a
room). Plus, keeping all push actions indefinitely uses a lot of disk space.
To fix these issues, we add a new table `event_push_summary` that tracks
per-user per-room counts of all notifications that happened before a stream
ordering S. Thus, to get the notification count for a user / room we can simply
query a single row in `event_push_summary` and count the number of rows in
`event_push_actions` with a stream ordering larger than S (and as long as S is
"recent", the number of rows needing to be scanned will be small).
The `event_push_summary` table is updated via a background job that periodically
chooses a new stream ordering S' (usually the latest stream ordering), counts
all notifications in `event_push_actions` between the existing S and S', and
adds them to the existing counts in `event_push_summary`.
This allows us to delete old rows from `event_push_actions` once those rows have
been counted and added to `event_push_summary` (we call this process
"rotation").
We need to handle when a user sends a read receipt to the room. Again this is
done as a background process. For each receipt we clear the row in
`event_push_summary` and count the number of notifications in
`event_push_actions` that happened after the receipt but before S, and insert
that count into `event_push_summary` (If the receipt happened *after* S then we
simply clear the `event_push_summary`.)
Note that its possible that if the read receipt is for an old event the relevant
`event_push_actions` rows will have been rotated and we get the wrong count
(it'll be too low). We accept this as a rare edge case that is unlikely to
impact the user much (since the vast majority of read receipts will be for the
latest event).
The last complication is to handle the race where we request the notifications
counts after a user sends a read receipt into the room, but *before* the
background update handles the receipt (without any special handling the counts
would be outdated). We fix this by including in `event_push_summary` the read
receipt we used when updating `event_push_summary`, and every time we query the
table we check if that matches the most recent read receipt in the room. If yes,
continue as above, if not we simply query the `event_push_actions` table
directly.
Since read receipts are almost always for recent events, scanning the
`event_push_actions` table in this case is unlikely to be a problem. Even if it
is a problem, it is temporary until the background job handles the new read
receipt.
"""
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
@ -19,7 +80,7 @@ import attr
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
@ -198,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn, txn,
user_id, user_id,
room_id, room_id,
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
) )
stream_ordering = None stream_ordering = None
@ -265,7 +330,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
counts.notify_count += row[1] counts.notify_count += row[1]
counts.unread_count += row[2] counts.unread_count += row[2]
# Next we need to count highlights, which aren't summarized # Next we need to count highlights, which aren't summarised
sql = """ sql = """
SELECT COUNT(*) FROM event_push_actions SELECT COUNT(*) FROM event_push_actions
WHERE user_id = ? WHERE user_id = ?
@ -280,7 +345,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Finally we need to count push actions that aren't included in the # Finally we need to count push actions that aren't included in the
# summary returned above, e.g. recent events that haven't been # summary returned above, e.g. recent events that haven't been
# summarized yet, or the summary is empty due to a recent read receipt. # summarised yet, or the summary is empty due to a recent read receipt.
stream_ordering = max(stream_ordering, summary_stream_ordering) stream_ordering = max(stream_ordering, summary_stream_ordering)
notify_count, unread_count = self._get_notif_unread_count_for_user_room( notify_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering txn, room_id, user_id, stream_ordering
@ -304,6 +369,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Does not consult `event_push_summary` table, which may include push Does not consult `event_push_summary` table, which may include push
actions that have been deleted from `event_push_actions` table. actions that have been deleted from `event_push_actions` table.
Args:
txn: The database transaction.
room_id: The room ID to get unread counts for.
user_id: The user ID to get unread counts for.
stream_ordering: The (exclusive) minimum stream ordering to consider.
max_stream_ordering: The (inclusive) maximum stream ordering to consider.
If this is not given, then no maximum is applied.
Return:
A tuple of the notif count and unread count in the given range.
""" """
# If there have been no events in the room since the stream ordering, # If there have been no events in the room since the stream ordering,
@ -376,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by ascending stream_ordering. The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
def get_after_receipt( def get_after_receipt(
@ -383,28 +460,41 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) -> List[Tuple[str, str, int, str, bool]]: ) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," receipt_types_clause, args = make_in_list_sql_clause(
" ep.highlight " self.database_engine,
" FROM (" "receipt_type",
" SELECT room_id," (
" MAX(stream_ordering) as stream_ordering" ReceiptTypes.READ,
" FROM events" ReceiptTypes.READ_PRIVATE,
" INNER JOIN receipts_linearized USING (room_id, event_id)" ReceiptTypes.UNSTABLE_READ_PRIVATE,
" WHERE receipt_type = 'm.read' AND user_id = ?" ),
" GROUP BY room_id" )
") AS rl,"
" event_push_actions AS ep" sql = f"""
" WHERE" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
" ep.room_id = rl.room_id" ep.highlight
" AND ep.stream_ordering > rl.stream_ordering" FROM (
" AND ep.user_id = ?" SELECT room_id,
" AND ep.stream_ordering > ?" MAX(stream_ordering) as stream_ordering
" AND ep.stream_ordering <= ?" FROM events
" AND ep.notif = 1" INNER JOIN receipts_linearized USING (room_id, event_id)
" ORDER BY ep.stream_ordering ASC LIMIT ?" WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id
) AS rl,
event_push_actions AS ep
WHERE
ep.room_id = rl.room_id
AND ep.stream_ordering > rl.stream_ordering
AND ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@ -418,24 +508,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt( def get_no_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]: ) -> List[Tuple[str, str, int, str, bool]]:
sql = ( receipt_types_clause, args = make_in_list_sql_clause(
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," self.database_engine,
" ep.highlight " "receipt_type",
" FROM event_push_actions AS ep" (
" INNER JOIN events AS e USING (room_id, event_id)" ReceiptTypes.READ,
" WHERE" ReceiptTypes.READ_PRIVATE,
" ep.room_id NOT IN (" ReceiptTypes.UNSTABLE_READ_PRIVATE,
" SELECT room_id FROM receipts_linearized" ),
" WHERE receipt_type = 'm.read' AND user_id = ?" )
" GROUP BY room_id"
" )" sql = f"""
" AND ep.user_id = ?" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
" AND ep.stream_ordering > ?" ep.highlight
" AND ep.stream_ordering <= ?" FROM event_push_actions AS ep
" AND ep.notif = 1" INNER JOIN events AS e USING (room_id, event_id)
" ORDER BY ep.stream_ordering ASC LIMIT ?" WHERE
ep.room_id NOT IN (
SELECT room_id FROM receipts_linearized
WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id
)
AND ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@ -485,34 +587,47 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by descending received_ts. The list will be ordered by descending received_ts.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
# find rooms that have a read receipt in them and return the most recent # find rooms that have a read receipt in them and return the most recent
# push actions # push actions
def get_after_receipt( def get_after_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]: ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = ( receipt_types_clause, args = make_in_list_sql_clause(
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," self.database_engine,
" ep.highlight, e.received_ts" "receipt_type",
" FROM (" (
" SELECT room_id," ReceiptTypes.READ,
" MAX(stream_ordering) as stream_ordering" ReceiptTypes.READ_PRIVATE,
" FROM events" ReceiptTypes.UNSTABLE_READ_PRIVATE,
" INNER JOIN receipts_linearized USING (room_id, event_id)" ),
" WHERE receipt_type = 'm.read' AND user_id = ?" )
" GROUP BY room_id"
") AS rl," sql = f"""
" event_push_actions AS ep" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
" INNER JOIN events AS e USING (room_id, event_id)" ep.highlight, e.received_ts
" WHERE" FROM (
" ep.room_id = rl.room_id" SELECT room_id,
" AND ep.stream_ordering > rl.stream_ordering" MAX(stream_ordering) as stream_ordering
" AND ep.user_id = ?" FROM events
" AND ep.stream_ordering > ?" INNER JOIN receipts_linearized USING (room_id, event_id)
" AND ep.stream_ordering <= ?" WHERE {receipt_types_clause} AND user_id = ?
" AND ep.notif = 1" GROUP BY room_id
" ORDER BY ep.stream_ordering DESC LIMIT ?" ) AS rl,
event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
ep.room_id = rl.room_id
AND ep.stream_ordering > rl.stream_ordering
AND ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@ -526,24 +641,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt( def get_no_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]: ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = ( receipt_types_clause, args = make_in_list_sql_clause(
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," self.database_engine,
" ep.highlight, e.received_ts" "receipt_type",
" FROM event_push_actions AS ep" (
" INNER JOIN events AS e USING (room_id, event_id)" ReceiptTypes.READ,
" WHERE" ReceiptTypes.READ_PRIVATE,
" ep.room_id NOT IN (" ReceiptTypes.UNSTABLE_READ_PRIVATE,
" SELECT room_id FROM receipts_linearized" ),
" WHERE receipt_type = 'm.read' AND user_id = ?" )
" GROUP BY room_id"
" )" sql = f"""
" AND ep.user_id = ?" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
" AND ep.stream_ordering > ?" ep.highlight, e.received_ts
" AND ep.stream_ordering <= ?" FROM event_push_actions AS ep
" AND ep.notif = 1" INNER JOIN events AS e USING (room_id, event_id)
" ORDER BY ep.stream_ordering DESC LIMIT ?" WHERE
ep.room_id NOT IN (
SELECT room_id FROM receipts_linearized
WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id
)
AND ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@ -769,12 +896,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# [10, <none>, 20], we should treat this as being equivalent to # [10, <none>, 20], we should treat this as being equivalent to
# [10, 10, 20]. # [10, 10, 20].
# #
sql = ( sql = """
"SELECT received_ts FROM events" SELECT received_ts FROM events
" WHERE stream_ordering <= ?" WHERE stream_ordering <= ?
" ORDER BY stream_ordering DESC" ORDER BY stream_ordering DESC
" LIMIT 1" LIMIT 1
) """
while range_end - range_start > 0: while range_end - range_start > 0:
middle = (range_end + range_start) // 2 middle = (range_end + range_start) // 2
@ -802,14 +929,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self, stream_ordering: int self, stream_ordering: int
) -> Optional[int]: ) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
sql = ( sql = """
"SELECT e.received_ts" SELECT e.received_ts
" FROM event_push_actions AS ep" FROM event_push_actions AS ep
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
" WHERE ep.stream_ordering > ? AND notif = 1" WHERE ep.stream_ordering > ? AND notif = 1
" ORDER BY ep.stream_ordering ASC" ORDER BY ep.stream_ordering ASC
" LIMIT 1" LIMIT 1
) """
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return cast(Optional[Tuple[int]], txn.fetchone()) return cast(Optional[Tuple[int]], txn.fetchone())
@ -858,10 +985,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Any push actions which predate the user's most recent read receipt are Any push actions which predate the user's most recent read receipt are
now redundant, so we can remove them from `event_push_actions` and now redundant, so we can remove them from `event_push_actions` and
update `event_push_summary`. update `event_push_summary`.
Returns true if all new receipts have been processed.
""" """
limit = 100 limit = 100
# The (inclusive) receipt stream ID that was previously processed..
min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn( min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="event_push_summary_last_receipt_stream_id", table="event_push_summary_last_receipt_stream_id",
@ -871,6 +1001,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
max_receipts_stream_id = self._receipts_id_gen.get_current_token() max_receipts_stream_id = self._receipts_id_gen.get_current_token()
# The (inclusive) event stream ordering that was previously summarised.
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
sql = """ sql = """
SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
FROM receipts_linearized AS r FROM receipts_linearized AS r
@ -895,13 +1033,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) )
rows = txn.fetchall() rows = txn.fetchall()
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
# For each new read receipt we delete push actions from before it and # For each new read receipt we delete push actions from before it and
# recalculate the summary. # recalculate the summary.
for _, room_id, user_id, stream_ordering in rows: for _, room_id, user_id, stream_ordering in rows:
@ -920,10 +1051,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering), (room_id, user_id, stream_ordering),
) )
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room( notif_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
) )
# Replace the previous summary with the new counts.
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="event_push_summary", table="event_push_summary",
@ -956,10 +1090,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return len(rows) < limit return len(rows) < limit
def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
"""Archives older notifications into event_push_summary. Returns whether """Archives older notifications (from event_push_actions) into event_push_summary.
the archiving process has caught up or not.
Returns whether the archiving process has caught up or not.
""" """
# The (inclusive) event stream ordering that was previously summarised.
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="event_push_summary_stream_ordering", table="event_push_summary_stream_ordering",
@ -974,7 +1110,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ? WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
""", """,
(old_rotate_stream_ordering, self._rotate_count), (old_rotate_stream_ordering, self._rotate_count),
) )
stream_row = txn.fetchone() stream_row = txn.fetchone()
@ -993,19 +1129,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) self._rotate_notifs_before_txn(
txn, old_rotate_stream_ordering, rotate_to_stream_ordering
)
return caught_up return caught_up
def _rotate_notifs_before_txn( def _rotate_notifs_before_txn(
self, txn: LoggingTransaction, rotate_to_stream_ordering: int self,
txn: LoggingTransaction,
old_rotate_stream_ordering: int,
rotate_to_stream_ordering: int,
) -> None: ) -> None:
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( """Archives older notifications (from event_push_actions) into event_push_summary.
txn,
table="event_push_summary_stream_ordering", Any event_push_actions between old_rotate_stream_ordering (exclusive) and
keyvalues={}, rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
retcol="stream_ordering", table.
)
Args:
txn: The database transaction.
old_rotate_stream_ordering: The previous maximum event stream ordering.
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
Returns whether the archiving process has caught up or not.
"""
# Calculate the new counts that should be upserted into event_push_summary # Calculate the new counts that should be upserted into event_push_summary
sql = """ sql = """
@ -1093,9 +1241,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def _remove_old_push_actions_that_have_rotated( async def _remove_old_push_actions_that_have_rotated(
self, self,
) -> None: ) -> None:
"""Clear out old push actions that have been summarized.""" """Clear out old push actions that have been summarised."""
# We want to clear out anything that older than a day that *has* already # We want to clear out anything that is older than a day that *has* already
# been rotated. # been rotated.
rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol( rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
table="event_push_summary_stream_ordering", table="event_push_summary_stream_ordering",
@ -1119,7 +1267,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering <= ? AND highlight = 0 WHERE stream_ordering <= ? AND highlight = 0
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
""", """,
( (
max_stream_ordering_to_delete, max_stream_ordering_to_delete,
batch_size, batch_size,
@ -1215,16 +1363,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# NB. This assumes event_ids are globally unique since # NB. This assumes event_ids are globally unique since
# it makes the query easier to index # it makes the query easier to index
sql = ( sql = """
"SELECT epa.event_id, epa.room_id," SELECT epa.event_id, epa.room_id,
" epa.stream_ordering, epa.topological_ordering," epa.stream_ordering, epa.topological_ordering,
" epa.actions, epa.highlight, epa.profile_tag, e.received_ts" epa.actions, epa.highlight, epa.profile_tag, e.received_ts
" FROM event_push_actions epa, events e" FROM event_push_actions epa, events e
" WHERE epa.event_id = e.event_id" WHERE epa.event_id = e.event_id
" AND epa.user_id = ? %s" AND epa.user_id = ? %s
" AND epa.notif = 1" AND epa.notif = 1
" ORDER BY epa.stream_ordering DESC" ORDER BY epa.stream_ordering DESC
" LIMIT ?" % (before_clause,) LIMIT ?
""" % (
before_clause,
) )
txn.execute(sql, args) txn.execute(sql, args)
return cast( return cast(

View file

@ -1490,7 +1490,7 @@ class PersistEventsStore:
event.sender, event.sender,
"url" in event.content and isinstance(event.content["url"], str), "url" in event.content and isinstance(event.content["url"], str),
event.get_state_key(), event.get_state_key(),
context.rejected or None, context.rejected,
) )
for event, context in events_and_contexts for event, context in events_and_contexts
), ),

View file

@ -600,7 +600,11 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
map from event id to result map from event id to result
""" """
event_entry_map = await self._get_events_from_cache( # Shortcut: check if we have any events in the *in memory* cache - this function
# may be called repeatedly for the same event so at this point we cannot reach
# out to any external cache for performance reasons. The external cache is
# checked later on in the `get_missing_events_from_cache_or_db` function below.
event_entry_map = self._get_events_from_local_cache(
event_ids, event_ids,
) )
@ -632,7 +636,9 @@ class EventsWorkerStore(SQLBaseStore):
if missing_events_ids: if missing_events_ids:
async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: async def get_missing_events_from_cache_or_db() -> Dict[
str, EventCacheEntry
]:
"""Fetches the events in `missing_event_ids` from the database. """Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow Also creates entries in `self._current_event_fetches` to allow
@ -657,10 +663,18 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event # the events have been redacted, and if so pulling the redaction event
# out of the database to check it. # out of the database to check it.
# #
missing_events = {}
try: try:
missing_events = await self._get_events_from_db( # Try to fetch from any external cache. We already checked the
# in-memory cache above.
missing_events = await self._get_events_from_external_cache(
missing_events_ids, missing_events_ids,
) )
# Now actually fetch any remaining events from the DB
db_missing_events = await self._get_events_from_db(
missing_events_ids - missing_events.keys(),
)
missing_events.update(db_missing_events)
except Exception as e: except Exception as e:
with PreserveLoggingContext(): with PreserveLoggingContext():
fetching_deferred.errback(e) fetching_deferred.errback(e)
@ -679,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore):
# cancellations, since multiple `_get_events_from_cache_or_db` calls can # cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch. # reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation( missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
get_missing_events_from_db() get_missing_events_from_cache_or_db()
) )
event_entry_map.update(missing_events) event_entry_map.update(missing_events)
@ -754,7 +768,54 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache( async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]: ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches. """Fetch events from the caches, both in memory and any external.
May return rejected events.
Args:
events: list of event_ids to fetch
update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = self._get_events_from_local_cache(
events, update_metrics=update_metrics
)
missing_event_ids = (e for e in events if e not in event_map)
event_map.update(
await self._get_events_from_external_cache(
events=missing_event_ids,
update_metrics=update_metrics,
)
)
return event_map
async def _get_events_from_external_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
"""Fetch events from any configured external cache.
May return rejected events.
Args:
events: list of event_ids to fetch
update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = {}
for event_id in events:
ret = await self._get_event_cache.get_external(
(event_id,), None, update_metrics=update_metrics
)
if ret:
event_map[event_id] = ret
return event_map
def _get_events_from_local_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
"""Fetch events from the local, in memory, caches.
May return rejected events. May return rejected events.
@ -766,7 +827,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events: for event_id in events:
# First check if it's in the event cache # First check if it's in the event cache
ret = await self._get_event_cache.get( ret = self._get_event_cache.get_local(
(event_id,), None, update_metrics=update_metrics (event_id,), None, update_metrics=update_metrics
) )
if ret: if ret:
@ -788,7 +849,7 @@ class EventsWorkerStore(SQLBaseStore):
# We add the entry back into the cache as we want to keep # We add the entry back into the cache as we want to keep
# recently queried events in the cache. # recently queried events in the cache.
await self._get_event_cache.set((event_id,), cache_entry) self._get_event_cache.set_local((event_id,), cache_entry)
return event_map return event_map
@ -2110,11 +2171,29 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn( def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str txn: LoggingTransaction, room_id: str
) -> List[str]: ) -> List[str]:
# we want to work through the events from oldest to newest, so
# we only want events whose prev_events do *not* have partial state - hence
# the 'NOT EXISTS' clause in the below.
#
# This is necessary because ordering by stream ordering isn't quite enough
# to ensure that we work from oldest to newest event (in particular,
# if an event is initially persisted as an outlier and later de-outliered,
# it can end up with a lower stream_ordering than its prev_events).
#
# Typically this means we'll only return one event per batch, but that's
# hard to do much about.
#
# See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute( txn.execute(
""" """
SELECT event_id FROM partial_state_events AS pse SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id) JOIN events USING (event_id)
WHERE pse.room_id = ? WHERE pse.room_id = ? AND
NOT EXISTS(
SELECT 1 FROM event_edges AS ee
JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
WHERE ee.event_id=pse.event_id
)
ORDER BY events.stream_ordering ORDER BY events.stream_ordering
LIMIT 100 LIMIT 100
""", """,

View file

@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str, room_id: str,
relation_type: Optional[str] = None, relation_type: Optional[str] = None,
event_type: Optional[str] = None, event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5, limit: int = 5,
direction: str = "b", direction: str = "b",
from_token: Optional[StreamToken] = None, from_token: Optional[StreamToken] = None,
@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to. room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given. relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given. event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events. limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`). oldest first (`"f"`).
@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?") where_clause.append("type = ?")
where_args.append(event_type) where_args.append(event_type)
if aggregation_key:
where_clause.append("aggregation_key = ?")
where_args.append(aggregation_key)
pagination_clause = generate_pagination_where_clause( pagination_clause = generate_pagination_where_clause(
direction=direction, direction=direction,
column_names=("topological_ordering", "stream_ordering"), column_names=("topological_ordering", "stream_ordering"),

View file

@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause( def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None] self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]: ) -> Tuple[Union[str, None], List[str]]:
if not room_types or not self.config.experimental.msc3827_enabled: if not room_types:
return None, [] return None, []
else: else:
# We use None when we want get rooms without a type # We use None when we want get rooms without a type

View file

@ -896,7 +896,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off # We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we # the hit ratio counts. After all, we don't populate the cache if we
# miss it here # miss it here
event_map = await self._get_events_from_cache( event_map = self._get_events_from_local_cache(
member_event_ids, update_metrics=False member_event_ids, update_metrics=False
) )

View file

@ -419,13 +419,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its # anything that was rejected should have the same state as its
# predecessor. # predecessor.
if context.rejected: if context.rejected:
assert context.state_group == context.state_group_before_event state_group = context.state_group_before_event
else:
state_group = context.state_group
self.db_pool.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
keyvalues={"event_id": event.event_id}, keyvalues={"event_id": event.event_id},
updatevalues={"state_group": context.state_group}, updatevalues={"state_group": state_group},
) )
self.db_pool.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
@ -440,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.call_after( txn.call_after(
self._get_state_group_for_event.prefill, self._get_state_group_for_event.prefill,
(event.event_id,), (event.event_id,),
context.state_group, state_group,
) )

View file

@ -58,6 +58,7 @@ from twisted.internet import defer
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -1346,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, next_token return rows, next_token
@trace
async def paginate_room_events( async def paginate_room_events(
self, self,
room_id: str, room_id: str,

View file

@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import StreamToken from synapse.types import StreamToken
@ -69,6 +70,7 @@ class EventSources:
) )
return token return token
@trace
async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the current token for a given room to be used to paginate """Get the current token for a given room to be used to paginate
events. events.

View file

@ -834,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
) -> Optional[VT]: ) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics) return self._lru_cache.get(key, update_metrics=update_metrics)
async def get_external(
self,
key: KT,
default: Optional[T] = None,
update_metrics: bool = True,
) -> Optional[VT]:
# This method should fetch from any configured external cache, in this case noop.
return None
def get_local(
self, key: KT, default: Optional[T] = None, update_metrics: bool = True
) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics)
async def set(self, key: KT, value: VT) -> None: async def set(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value) self._lru_cache.set(key, value)
def set_local(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value)
async def invalidate(self, key: KT) -> None: async def invalidate(self, key: KT) -> None:
# This method should invalidate any external cache and then invalidate the LruCache. # This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key) return self._lru_cache.invalidate(key)

View file

@ -21,7 +21,7 @@ from typing import Any, DefaultDict, Iterator, List, Set
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.logging.context import ( from synapse.logging.context import (
PreserveLoggingContext, PreserveLoggingContext,
make_deferred_yieldable, make_deferred_yieldable,
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class FederationRateLimiter: class FederationRateLimiter:
def __init__(self, clock: Clock, config: FederationRateLimitConfig): def __init__(self, clock: Clock, config: FederationRatelimitSettings):
def new_limiter() -> "_PerHostRatelimiter": def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(clock=clock, config=config) return _PerHostRatelimiter(clock=clock, config=config)
@ -63,7 +63,7 @@ class FederationRateLimiter:
class _PerHostRatelimiter: class _PerHostRatelimiter:
def __init__(self, clock: Clock, config: FederationRateLimitConfig): def __init__(self, clock: Clock, config: FederationRatelimitSettings):
""" """
Args: Args:
clock clock

View file

@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -51,6 +52,7 @@ MEMBERSHIP_PRIORITY = (
_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "") _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
@trace
async def filter_events_for_client( async def filter_events_for_client(
storage: StorageControllers, storage: StorageControllers,
user_id: str, user_id: str,

View file

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity) self.assertTrue(complexity > 0, complexity)
@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23) self.assertEqual(complexity, 1.23)

View file

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
@ -256,7 +255,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier, RoomVersions.V7.identifier,
), ),
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as # Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that # part of the stripped room state, as the knocking homeserver already has that
@ -294,7 +293,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id), % (room_id, signed_knock_event.event_id),
signed_knock_event_json, signed_knock_event_json,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
# Check that we got the stripped room state in return # Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"] room_state_events = channel.json_body["knock_state_events"]

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from typing import Any, Dict from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
access_token=self.token, access_token=self.token,
) )
self.assertEqual(req.code, HTTPStatus.OK, req) self.assertEqual(req.code, 200, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None: def test_global_account_data_deleted_upon_deactivation(self) -> None:
""" """

View file

@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config return config
def prepare( def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass")
self.denied_user_id = self.register_user("denied", "pass") self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass") self.denied_access_token = self.login("denied", "pass")
return hs
def test_denied_without_publication_permission(self) -> None: def test_denied_without_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def prepare( def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler() self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
return hs
def test_disabling_room_list(self) -> None: def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True

View file

@ -14,6 +14,7 @@
import logging import logging
from typing import cast from typing import cast
from unittest import TestCase from unittest import TestCase
from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -22,6 +23,7 @@ from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseErro
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.federation.federation_client import SendJoinResult
from synapse.logging.context import LoggingContext, run_in_background from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
@ -30,7 +32,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import event_injection from tests.test_utils import event_injection, make_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -280,13 +282,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we poke this directly into _process_received_pdu, to avoid the # we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event. # federation handler wanting to backfill the fake event.
state_handler = self.hs.get_state_handler()
context = self.get_success(
state_handler.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)
self.get_success( self.get_success(
federation_event_handler._process_received_pdu( federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, self.OTHER_SERVER_NAME,
event, event,
state_ids={ context,
(e.type, e.state_key): e.event_id for e in current_state
},
) )
) )
@ -448,3 +458,121 @@ class EventFromPduTestCase(TestCase):
}, },
RoomVersions.V6, RoomVersions.V6,
) )
class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
def test_failed_partial_join_is_clean(self) -> None:
"""
Tests that, when failing to partial-join a room, we don't get stuck with
a partial-state flag on a room.
"""
fed_handler = self.hs.get_federation_handler()
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
membership_event = make_event_from_dict(
{
"room_id": room_id,
"type": "m.room.member",
"sender": "@alice:test",
"state_key": "@alice:test",
"content": {"membership": "join"},
},
RoomVersions.V10,
)
mock_make_membership_event = Mock(
return_value=make_awaitable(
(
"example.com",
membership_event,
RoomVersions.V10,
)
)
)
EVENT_CREATE = make_event_from_dict(
{
"room_id": room_id,
"type": "m.room.create",
"sender": "@kristina:example.com",
"state_key": "",
"depth": 0,
"content": {"creator": "@kristina:example.com", "room_version": "10"},
"auth_events": [],
"origin_server_ts": 1,
},
room_version=RoomVersions.V10,
)
EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
{
"room_id": room_id,
"type": "m.room.member",
"sender": "@kristina:example.com",
"state_key": "@kristina:example.com",
"content": {"membership": "join"},
"depth": 1,
"prev_events": [EVENT_CREATE.event_id],
"auth_events": [EVENT_CREATE.event_id],
"origin_server_ts": 1,
},
room_version=RoomVersions.V10,
)
EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
{
"room_id": room_id,
"type": "m.room.member",
"sender": "@kristina:example.com",
"state_key": "@alice:test",
"content": {"membership": "invite"},
"depth": 2,
"prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
"auth_events": [
EVENT_CREATE.event_id,
EVENT_CREATOR_MEMBERSHIP.event_id,
],
"origin_server_ts": 1,
},
room_version=RoomVersions.V10,
)
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
membership_event,
"example.com",
state=[
EVENT_CREATE,
EVENT_CREATOR_MEMBERSHIP,
EVENT_INVITATION_MEMBERSHIP,
],
auth_chain=[
EVENT_CREATE,
EVENT_CREATOR_MEMBERSHIP,
EVENT_INVITATION_MEMBERSHIP,
],
partial_state=True,
servers_in_room=["example.com"],
)
)
)
with patch.object(
fed_client, "make_membership_event", mock_make_membership_event
), patch.object(fed_client, "send_join", mock_send_join):
# Join and check that our join event is rejected
# (The join event is rejected because it doesn't have any signatures)
join_exc = self.get_failure(
fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
SynapseError,
)
self.assertIn("Join event was rejected", str(join_exc))
store = self.hs.get_datastores().main
# Check that we don't have a left-over partial_state entry.
self.assertFalse(
self.get_success(store.is_partial_state_room(room_id)),
f"Stale partial-stated room flag left over for {room_id} after a"
f" failed do_invite_join!",
)

View file

@ -314,4 +314,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", path, content={}, access_token=self.access_token "POST", path, content={}, access_token=self.access_token
) )
self.assertEqual(int(channel.result["code"]), 403) self.assertEqual(channel.code, 403)

View file

@ -15,6 +15,8 @@
from copy import deepcopy from copy import deepcopy
from typing import List from typing import List
from parameterized import parameterized
from synapse.api.constants import EduTypes, ReceiptTypes from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.types import JsonDict from synapse.types import JsonDict
@ -25,13 +27,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt self.event_source = hs.get_event_sources().sources.receipt
def test_filters_out_private_receipt(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_private_receipt(self, receipt_type: str) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1435641916114394fHBLK:matrix.org": { "$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
} }
@ -45,13 +50,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[], [],
) )
def test_filters_out_private_receipt_and_ignores_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_private_receipt_and_ignores_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -84,13 +94,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$14356419edgd14394fHBLK:matrix.org": { "$14356419edgd14394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -125,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_handles_empty_event(self): def test_handles_empty_event(self) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
@ -160,13 +175,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$14356419edgd14394fHBLK:matrix.org": { "$14356419edgd14394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -207,7 +227,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_handles_string_data(self): def test_handles_string_data(self) -> None:
""" """
Tests that an invalid shape for read-receipts is handled. Tests that an invalid shape for read-receipts is handled.
Context: https://github.com/matrix-org/synapse/issues/10603 Context: https://github.com/matrix-org/synapse/issues/10603
@ -242,13 +262,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_leaves_our_private_and_their_public(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@me:server.org": { "@me:server.org": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -273,7 +296,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@me:server.org": { "@me:server.org": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -296,13 +319,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_we_do_not_mutate(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_we_do_not_mutate(self, receipt_type: str) -> None:
"""Ensure the input values are not modified.""" """Ensure the input values are not modified."""
events = [ events = [
{ {
"content": { "content": {
"$1435641916114394fHBLK:matrix.org": { "$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
} }
@ -320,7 +346,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def _test_filters_private( def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict] self, events: List[JsonDict], expected_output: List[JsonDict]
): ) -> None:
"""Tests that the _filter_out_private returns the expected output""" """Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private_receipts( filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org" events, "@me:server.org"

View file

@ -1,4 +1,3 @@
from http import HTTPStatus
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -260,7 +259,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
f"/_matrix/client/v3/rooms/{self.room_id}/join", f"/_matrix/client/v3/rooms/{self.room_id}/join",
access_token=self.bob_token, access_token=self.bob_token,
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
# wait for join to arrive over replication # wait for join to arrive over replication
self.replicate() self.replicate()

View file

@ -15,7 +15,6 @@
import inspect import inspect
import itertools import itertools
import logging import logging
from http import HTTPStatus
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -78,7 +77,7 @@ def test_disconnect(
if expect_cancellation: if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED expected_code = HTTP_STATUS_REQUEST_CANCELLED
else: else:
expected_code = HTTPStatus.OK expected_code = 200
request = channel.request request = channel.request
if channel.is_finished(): if channel.is_finished():

View file

@ -16,6 +16,7 @@ from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EduTypes, EventTypes from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState from synapse.handlers.presence import UserPresenceState
@ -532,6 +533,34 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone") self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"]) self.assertIsNone(res["avatar_url"])
def test_update_room_membership_remote_join(self):
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
self.hs.get_room_member_handler()._remote_join = mocked_remote_join
fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to
# be persisted and the module API method to raise that.
self.get_failure(
defer.ensureDeferred(
self.module_api.update_room_membership(
sender=f"@user:{self.module_api.server_name}",
target=f"@user:{self.module_api.server_name}",
room_id=f"!nonexistent:{fake_remote_host}",
new_membership="join",
remote_room_hosts=[fake_remote_host],
)
),
NotFoundError,
)
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
def test_get_room_state(self): def test_get_room_state(self):
"""Tests that a module can retrieve the state of a room through the module API.""" """Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme") user_id = self.register_user("peter", "hackme")
@ -635,6 +664,76 @@ class ModuleApiTestCase(HomeserverTestCase):
[{"set_tweak": "sound", "value": "default"}] [{"set_tweak": "sound", "value": "default"}]
) )
def test_lookup_room_alias(self) -> None:
"""Test that modules can resolve a room alias to a room ID."""
password = "password"
user_id = self.register_user("user", password)
access_token = self.login(user_id, password)
room_alias = "my-alias"
reference_room_id = self.helper.create_room_as(
tok=access_token, extra_content={"room_alias_name": room_alias}
)
self.assertIsNotNone(reference_room_id)
(room_id, _) = self.get_success(
self.module_api.lookup_room_alias(
f"#{room_alias}:{self.module_api.server_name}"
)
)
self.assertEqual(room_id, reference_room_id)
def test_create_room(self) -> None:
"""Test that modules can create a room."""
# First test user validation (i.e. user is local).
self.get_failure(
self.module_api.create_room(
user_id=f"@user:{self.module_api.server_name}abc",
config={},
ratelimit=False,
),
RuntimeError,
)
# Now do the happy path.
user_id = self.register_user("user", "password")
access_token = self.login(user_id, "password")
room_id, room_alias = self.get_success(
self.module_api.create_room(
user_id=user_id, config={"room_alias_name": "foo-bar"}, ratelimit=False
)
)
# Check room creator.
channel = self.make_request(
"GET",
f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias.
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias.
room_id, room_alias = self.get_success(
self.module_api.create_room(user_id=user_id, config={}, ratelimit=False)
)
# Check room creator.
channel = self.make_request(
"GET",
f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias.
self.assertIsNone(room_alias)
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup""" """For testing ModuleApi functionality in a multi-worker setup"""

View file

@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self) -> None: def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False) channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys()) {"server_version", "python_version"}, set(channel.json_body.keys())
) )
@ -139,7 +139,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
) )
# Should be successful # Should be successful
self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(200, channel.code)
# Quarantine the media # Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@ -152,7 +152,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.pump(1.0) self.pump(1.0)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Attempt to access the media # Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id) self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@ -209,7 +209,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.pump(1.0) self.pump(1.0)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
) )
@ -251,7 +251,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.pump(1.0) self.pump(1.0)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
) )
@ -285,7 +285,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok) channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0) self.pump(1.0)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Quarantine all media by this user # Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@ -297,7 +297,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.pump(1.0) self.pump(1.0)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
) )
@ -318,10 +318,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined # Shouldn't be quarantined
self.assertEqual( self.assertEqual(
HTTPStatus.OK, 200,
channel.code, channel.code,
msg=( msg=(
"Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2 % server_and_media_id_2
), ),
) )
@ -350,7 +350,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self) -> None: def test_purge_history(self) -> None:
""" """
Simple test of purge history API. Simple test of purge history API.
Test only that is is possible to call, get status HTTPStatus.OK and purge_id. Test only that is is possible to call, get status 200 and purge_id.
""" """
channel = self.make_request( channel = self.make_request(
@ -360,7 +360,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body) self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"] purge_id = channel.json_body["purge_id"]
@ -371,5 +371,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"]) self.assertEqual("complete", channel.json_body["status"])

View file

@ -125,7 +125,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status", "/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running. # Background updates should be enabled, but none should be running.
self.assertDictEqual( self.assertDictEqual(
@ -147,7 +147,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status", "/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running. # Background updates should be enabled, and one should be running.
self.assertDictEqual( self.assertDictEqual(
@ -181,7 +181,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled", "/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True}) self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates # Disable the BG updates
@ -191,7 +191,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False}, content={"enabled": False},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False}) self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in # Advance a bit and get the current status, note this will finish the in
@ -204,7 +204,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status", "/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual( self.assertDictEqual(
channel.json_body, channel.json_body,
{ {
@ -231,7 +231,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status", "/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response. # There should be no change from the previous /status response.
self.assertDictEqual( self.assertDictEqual(
@ -259,7 +259,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True}, content={"enabled": True},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True}) self.assertDictEqual(channel.json_body, {"enabled": True})
@ -270,7 +270,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status", "/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress. # Background updates should be enabled and making progress.
self.assertDictEqual( self.assertDictEqual(
@ -325,7 +325,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# test that each background update is waiting now # test that each background update is waiting now
for update in updates: for update in updates:

View file

@ -122,7 +122,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
def test_unknown_device(self) -> None: def test_unknown_device(self) -> None:
""" """
Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or 200.
""" """
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user self.other_user
@ -143,7 +143,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
channel = self.make_request( channel = self.make_request(
"DELETE", "DELETE",
@ -151,8 +151,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
# Delete unknown device returns status HTTPStatus.OK # Delete unknown device returns status 200
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self) -> None: def test_update_device_too_long_display_name(self) -> None:
""" """
@ -189,12 +189,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"]) self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self) -> None: def test_update_no_display_name(self) -> None:
""" """
Tests that a update for a device without JSON returns a HTTPStatus.OK Tests that a update for a device without JSON returns a 200
""" """
# Set iniital display name. # Set iniital display name.
update = {"display_name": "new display"} update = {"display_name": "new display"}
@ -210,7 +210,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated. # Ensure the display name was not updated.
channel = self.make_request( channel = self.make_request(
@ -219,7 +219,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"]) self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self) -> None: def test_update_display_name(self) -> None:
@ -234,7 +234,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"}, content={"display_name": "new displayname"},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name # Check new display_name
channel = self.make_request( channel = self.make_request(
@ -243,7 +243,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"]) self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self) -> None: def test_get_device(self) -> None:
@ -256,7 +256,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available # Check that all fields are available
self.assertIn("user_id", channel.json_body) self.assertIn("user_id", channel.json_body)
@ -281,7 +281,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased # Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user)) res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@ -379,7 +379,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"])) self.assertEqual(0, len(channel.json_body["devices"]))
@ -399,7 +399,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@ -494,7 +494,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
def test_unknown_devices(self) -> None: def test_unknown_devices(self) -> None:
""" """
Tests that a remove of a device that does not exist returns HTTPStatus.OK. Tests that a remove of a device that does not exist returns 200.
""" """
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -503,8 +503,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]}, content={"devices": ["unknown_device1", "unknown_device2"]},
) )
# Delete unknown devices returns status HTTPStatus.OK # Delete unknown devices returns status 200
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_delete_devices(self) -> None: def test_delete_devices(self) -> None:
""" """
@ -533,7 +533,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids}, content={"devices": device_ids},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user)) res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res)) self.assertEqual(0, len(res))

View file

@ -117,7 +117,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -134,7 +134,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
@ -151,7 +151,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -168,7 +168,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertEqual(len(channel.json_body["event_reports"]), 10)
@ -185,7 +185,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10) self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -205,7 +205,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10) self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -225,7 +225,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5) self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -247,7 +247,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1 report = 1
@ -265,7 +265,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1 report = 1
@ -344,7 +344,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -357,7 +357,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -370,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19) self.assertEqual(channel.json_body["next_token"], 19)
@ -384,7 +384,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -400,7 +400,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"}, {"score": -100, "reason": "this makes me sad"},
access_token=user_tok, access_token=user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters( def _create_event_and_report_without_parameters(
self, room_id: str, user_tok: str self, room_id: str, user_tok: str
@ -415,7 +415,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{}, {},
access_token=user_tok, access_token=user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: List[JsonDict]) -> None: def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in an event report""" """Checks that all attributes are present in an event report"""
@ -502,7 +502,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
def test_invalid_report_id(self) -> None: def test_invalid_report_id(self) -> None:
@ -594,7 +594,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"}, {"score": -100, "reason": "this makes me sad"},
access_token=user_tok, access_token=user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: JsonDict) -> None: def _check_fields(self, content: JsonDict) -> None:
"""Checks that all attributes are present in a event report""" """Checks that all attributes are present in a event report"""

View file

@ -142,7 +142,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 5) self.assertEqual(len(channel.json_body["destinations"]), 5)
self.assertEqual(channel.json_body["next_token"], "5") self.assertEqual(channel.json_body["next_token"], "5")
@ -160,7 +160,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 15) self.assertEqual(len(channel.json_body["destinations"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["destinations"]), 10) self.assertEqual(len(channel.json_body["destinations"]), 10)
@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -211,7 +211,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -224,7 +224,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 19) self.assertEqual(len(channel.json_body["destinations"]), 19)
self.assertEqual(channel.json_body["next_token"], "19") self.assertEqual(channel.json_body["next_token"], "19")
@ -238,7 +238,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 1) self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -255,7 +255,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_destinations, len(channel.json_body["destinations"])) self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
self.assertEqual(number_destinations, channel.json_body["total"]) self.assertEqual(number_destinations, channel.json_body["total"])
@ -290,7 +290,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_destination_list)) self.assertEqual(channel.json_body["total"], len(expected_destination_list))
returned_order = [ returned_order = [
@ -376,7 +376,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that destinations were returned # Check that destinations were returned
self.assertTrue("destinations" in channel.json_body) self.assertTrue("destinations" in channel.json_body)
@ -418,7 +418,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"]) self.assertEqual("sub0.example.com", channel.json_body["destination"])
# Check that all fields are available # Check that all fields are available
@ -435,7 +435,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"]) self.assertEqual("sub0.example.com", channel.json_body["destination"])
self.assertEqual(0, channel.json_body["retry_last_ts"]) self.assertEqual(0, channel.json_body["retry_last_ts"])
self.assertEqual(0, channel.json_body["retry_interval"]) self.assertEqual(0, channel.json_body["retry_interval"])
@ -452,7 +452,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
retry_timings = self.get_success( retry_timings = self.get_success(
self.store.get_destination_retry_timings("sub0.example.com") self.store.get_destination_retry_timings("sub0.example.com")
@ -619,7 +619,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 3) self.assertEqual(len(channel.json_body["rooms"]), 3)
self.assertEqual(channel.json_body["next_token"], "3") self.assertEqual(channel.json_body["next_token"], "3")
@ -637,7 +637,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 5) self.assertEqual(len(channel.json_body["rooms"]), 5)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -655,7 +655,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(channel.json_body["next_token"], "8") self.assertEqual(channel.json_body["next_token"], "8")
self.assertEqual(len(channel.json_body["rooms"]), 5) self.assertEqual(len(channel.json_body["rooms"]), 5)
@ -673,7 +673,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body) self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body)
self.assertEqual(channel_asc.json_body["total"], number_rooms) self.assertEqual(channel_asc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"])) self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
self._check_fields(channel_asc.json_body["rooms"]) self._check_fields(channel_asc.json_body["rooms"])
@ -685,7 +685,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body) self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body)
self.assertEqual(channel_desc.json_body["total"], number_rooms) self.assertEqual(channel_desc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"])) self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
self._check_fields(channel_desc.json_body["rooms"]) self._check_fields(channel_desc.json_body["rooms"])
@ -711,7 +711,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -724,7 +724,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -737,7 +737,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 4) self.assertEqual(len(channel.json_body["rooms"]), 4)
self.assertEqual(channel.json_body["next_token"], "4") self.assertEqual(channel.json_body["next_token"], "4")
@ -751,7 +751,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 1) self.assertEqual(len(channel.json_body["rooms"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -767,7 +767,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"]) self._check_fields(channel.json_body["rooms"])

View file

@ -131,7 +131,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource, upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=HTTPStatus.OK, expect_code=200,
) )
# Extract media ID from the response # Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@ -151,11 +151,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful # Should be successful
self.assertEqual( self.assertEqual(
HTTPStatus.OK, 200,
channel.code, channel.code,
msg=( msg=(
"Expected to receive a HTTPStatus.OK on accessing media: %s" "Expected to receive a 200 on accessing media: %s" % server_and_media_id
% server_and_media_id
), ),
) )
@ -172,7 +171,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
media_id, media_id,
@ -388,7 +387,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms), self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
media_id, media_id,
@ -413,7 +412,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms), self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id) self._access_media(server_and_media_id)
@ -425,7 +424,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms), self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
server_and_media_id.split("/")[1], server_and_media_id.split("/")[1],
@ -449,7 +448,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id) self._access_media(server_and_media_id)
@ -460,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
server_and_media_id.split("/")[1], server_and_media_id.split("/")[1],
@ -485,7 +484,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
channel = self.make_request( channel = self.make_request(
@ -493,7 +492,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id) self._access_media(server_and_media_id)
@ -504,7 +503,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
server_and_media_id.split("/")[1], server_and_media_id.split("/")[1],
@ -530,7 +529,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"url": "mxc://%s" % (server_and_media_id,)}, content={"url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
channel = self.make_request( channel = self.make_request(
@ -538,7 +537,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id) self._access_media(server_and_media_id)
@ -549,7 +548,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual( self.assertEqual(
server_and_media_id.split("/")[1], server_and_media_id.split("/")[1],
@ -569,7 +568,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
upload_resource, upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=HTTPStatus.OK, expect_code=200,
) )
# Extract media ID from the response # Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@ -602,10 +601,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success: if expect_success:
self.assertEqual( self.assertEqual(
HTTPStatus.OK, 200,
channel.code, channel.code,
msg=( msg=(
"Expected to receive a HTTPStatus.OK on accessing media: %s" "Expected to receive a 200 on accessing media: %s"
% server_and_media_id % server_and_media_id
), ),
) )
@ -648,7 +647,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource, upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=HTTPStatus.OK, expect_code=200,
) )
# Extract media ID from the response # Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@ -712,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
@ -726,7 +725,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
@ -753,7 +752,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
# verify that is not in quarantine # verify that is not in quarantine
@ -785,7 +784,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource, upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=HTTPStatus.OK, expect_code=200,
) )
# Extract media ID from the response # Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@ -845,7 +844,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
@ -859,7 +858,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))

View file

@ -105,7 +105,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16) self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -129,7 +129,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
@ -150,7 +150,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16) self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -207,7 +207,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data, data,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) self.assertEqual(200, channel1.code, msg=channel1.json_body)
channel2 = self.make_request( channel2 = self.make_request(
"POST", "POST",
@ -251,7 +251,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0}, {"uses_allowed": 0},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer # Should fail with negative integer
@ -321,7 +321,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64}, {"length": 64},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 64) self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0 # Should fail with 0
@ -439,7 +439,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1}, {"uses_allowed": 1},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -450,7 +450,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0}, {"uses_allowed": 0},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -461,7 +461,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None}, {"uses_allowed": None},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -506,7 +506,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time}, {"expiry_time": new_expiry_time},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
@ -517,7 +517,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None}, {"expiry_time": None},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
@ -568,7 +568,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
@ -655,7 +655,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# GETTING ONE # GETTING ONE
@ -716,7 +716,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["expiry_time"])
@ -762,7 +762,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 1) self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0] token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token) self.assertEqual(token_info["token"], token)
@ -816,7 +816,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 2) self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0] token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1] token_info_2 = channel.json_body["registration_tokens"][1]

View file

@ -94,7 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_room_is_not_valid(self) -> None: def test_room_is_not_valid(self) -> None:
""" """
@ -127,7 +127,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body) self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body) self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -202,7 +202,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -233,7 +233,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -265,7 +265,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -296,7 +296,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
) )
# The room is now blocked. # The room is now blocked.
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id) self._is_blocked(room_id)
def test_shutdown_room_consent(self) -> None: def test_shutdown_room_consent(self) -> None:
@ -337,7 +337,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body) self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -366,7 +366,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
{"history_visibility": "world_readable"}, {"history_visibility": "world_readable"},
access_token=self.other_user_tok, access_token=self.other_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged # Test that room is not purged
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -383,7 +383,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body) self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body)
@ -522,7 +522,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -533,7 +533,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
@ -574,7 +574,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -639,7 +639,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id1 = channel.json_body["delete_id"] delete_id1 = channel.json_body["delete_id"]
@ -654,7 +654,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id2 = channel.json_body["delete_id"] delete_id2 = channel.json_body["delete_id"]
@ -665,7 +665,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, len(channel.json_body["results"])) self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"]) self.assertEqual("complete", channel.json_body["results"][1]["status"])
@ -682,7 +682,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
@ -733,7 +733,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# get result of first call # get result of first call
first_channel.await_result() first_channel.await_result()
self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
self.assertIn("delete_id", first_channel.json_body) self.assertIn("delete_id", first_channel.json_body)
# check status after finish the task # check status after finish the task
@ -764,7 +764,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -795,7 +795,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -827,7 +827,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -876,7 +876,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -887,7 +887,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id, self.url_status_by_room_id,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room # Test that member has moved to new room
@ -914,7 +914,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
content={"history_visibility": "world_readable"}, content={"history_visibility": "world_readable"},
access_token=self.other_user_tok, access_token=self.other_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged # Test that room is not purged
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -931,7 +931,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body) self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"] delete_id = channel.json_body["delete_id"]
@ -942,7 +942,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id, self.url_status_by_room_id,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room # Test that member has moved to new room
@ -1026,9 +1026,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id, self.url_status_by_room_id,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual( self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body)
HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
)
self.assertEqual(1, len(channel_room_id.json_body["results"])) self.assertEqual(1, len(channel_room_id.json_body["results"]))
self.assertEqual( self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"] delete_id, channel_room_id.json_body["results"][0]["delete_id"]
@ -1041,7 +1039,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual( self.assertEqual(
HTTPStatus.OK, 200,
channel_delete_id.code, channel_delete_id.code,
msg=channel_delete_id.json_body, msg=channel_delete_id.json_body,
) )
@ -1100,7 +1098,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
) )
# Check request completed successfully # Check request completed successfully
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key # Check that response json body contains a "rooms" key
self.assertTrue( self.assertTrue(
@ -1186,7 +1184,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body) self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]: for r in channel.json_body["rooms"]:
@ -1226,7 +1224,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self) -> None: def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned""" """Test the correct attributes for a room are returned"""
@ -1253,7 +1251,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id}, {"room_id": room_id},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room # Set this new alias as the canonical alias for this room
self.helper.send_state( self.helper.send_state(
@ -1285,7 +1283,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned # Check that rooms were returned
self.assertTrue("rooms" in channel.json_body) self.assertTrue("rooms" in channel.json_body)
@ -1341,7 +1339,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned # Check that rooms were returned
self.assertTrue("rooms" in channel.json_body) self.assertTrue("rooms" in channel.json_body)
@ -1487,7 +1485,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test( def _search_test(
expected_room_id: Optional[str], expected_room_id: Optional[str],
search_term: str, search_term: str,
expected_http_code: int = HTTPStatus.OK, expected_http_code: int = 200,
) -> None: ) -> None:
"""Search for a room and check that the returned room's id is a match """Search for a room and check that the returned room's id is a match
@ -1505,7 +1503,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
if expected_http_code != HTTPStatus.OK: if expected_http_code != 200:
return return
# Check that rooms were returned # Check that rooms were returned
@ -1585,7 +1583,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name")) self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
@ -1618,7 +1616,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body) self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body) self.assertIn("name", channel.json_body)
@ -1650,7 +1648,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"]) self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room # Have another user join the room
@ -1664,7 +1662,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"]) self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room # leave room
@ -1676,7 +1674,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"]) self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self) -> None: def test_room_members(self) -> None:
@ -1707,7 +1705,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual( self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@ -1720,7 +1718,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual( self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
@ -1738,7 +1736,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body) self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that # testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got # the create_room already does the right thing, so no need to verify that we got
@ -1755,7 +1753,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id}, {"room_id": room_id},
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room # Set this new alias as the canonical alias for this room
self.helper.send_state( self.helper.send_state(
@ -1772,6 +1770,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
tok=admin_user_tok, tok=admin_user_tok,
) )
def test_get_joined_members_after_leave_room(self) -> None:
"""Test that requesting room members after leaving the room raises a 403 error."""
# create the room
user = self.register_user("foo", "pass")
user_tok = self.login("foo", "pass")
room_id = self.helper.create_room_as(user, tok=user_tok)
self.helper.leave(room_id, user, tok=user_tok)
# delete the rooms and get joined roomed membership
url = f"/_matrix/client/r0/rooms/{room_id}/joined_members"
channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok)
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
class JoinAliasRoomTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
@ -1873,7 +1886,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"]) self.assertEqual(
"Can't join remote room because no servers that are in the room have been provided.",
channel.json_body["error"],
)
def test_room_is_not_valid(self) -> None: def test_room_is_not_valid(self) -> None:
""" """
@ -1906,7 +1922,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"]) self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room # Validate if user is a member of the room
@ -1916,7 +1932,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms", "/_matrix/client/r0/joined_rooms",
access_token=self.second_tok, access_token=self.second_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self) -> None: def test_join_private_room_if_not_member(self) -> None:
@ -1964,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms", "/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room. # Join user to room.
@ -1977,7 +1993,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content={"user_id": self.second_user_id}, content={"user_id": self.second_user_id},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"]) self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room # Validate if user is a member of the room
@ -1987,7 +2003,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms", "/_matrix/client/r0/joined_rooms",
access_token=self.second_tok, access_token=self.second_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self) -> None: def test_join_private_room_if_owner(self) -> None:
@ -2007,7 +2023,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"]) self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room # Validate if user is a member of the room
@ -2017,7 +2033,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms", "/_matrix/client/r0/joined_rooms",
access_token=self.second_tok, access_token=self.second_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self) -> None: def test_context_as_non_admin(self) -> None:
@ -2081,7 +2097,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]), % (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
channel.json_body["event"]["event_id"], events[midway]["event_id"] channel.json_body["event"]["event_id"], events[midway]["event_id"]
) )
@ -2140,7 +2156,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user. # Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@ -2167,7 +2183,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an # Now we test that we can join the room (we should have received an
# invite) and can ban a user. # invite) and can ban a user.
@ -2193,7 +2209,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user. # Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok) self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@ -2336,7 +2352,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True}, content={"block": True},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"]) self.assertTrue(channel.json_body["block"])
self._is_blocked(room_id, expect=True) self._is_blocked(room_id, expect=True)
@ -2360,7 +2376,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True}, content={"block": True},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"]) self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True) self._is_blocked(self.room_id, expect=True)
@ -2376,7 +2392,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False}, content={"block": False},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"]) self.assertFalse(channel.json_body["block"])
self._is_blocked(room_id, expect=False) self._is_blocked(room_id, expect=False)
@ -2400,7 +2416,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False}, content={"block": False},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"]) self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False) self._is_blocked(self.room_id, expect=False)
@ -2415,7 +2431,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id, self.url % room_id,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"]) self.assertTrue(channel.json_body["block"])
self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertEqual(self.other_user, channel.json_body["user_id"])
@ -2439,7 +2455,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id, self.url % room_id,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"]) self.assertFalse(channel.json_body["block"])
self.assertNotIn("user_id", channel.json_body) self.assertNotIn("user_id", channel.json_body)

View file

@ -197,7 +197,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"}, "content": {"msgtype": "m.text", "body": "test msg one"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite # user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@ -226,7 +226,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"}, "content": {"msgtype": "m.text", "body": "test msg two"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has no new invites or memberships # user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1) self._check_invite_and_join_status(self.other_user, 0, 1)
@ -260,7 +260,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"}, "content": {"msgtype": "m.text", "body": "test msg one"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite # user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@ -301,7 +301,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"}, "content": {"msgtype": "m.text", "body": "test msg two"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite # user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@ -341,7 +341,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"}, "content": {"msgtype": "m.text", "body": "test msg one"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite # user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@ -388,7 +388,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"}, "content": {"msgtype": "m.text", "body": "test msg two"},
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite # user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@ -538,7 +538,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token "GET", "/_matrix/client/r0/sync", access_token=token
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
# Get the messages # Get the messages
room = channel.json_body["rooms"]["join"][room_id] room = channel.json_body["rooms"]["join"][room_id]

View file

@ -204,7 +204,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10) self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
@ -222,7 +222,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15) self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -240,7 +240,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10) self.assertEqual(len(channel.json_body["users"]), 10)
@ -262,7 +262,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users) self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -275,7 +275,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users) self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -288,7 +288,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19) self.assertEqual(channel.json_body["next_token"], 19)
@ -301,7 +301,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1) self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -318,7 +318,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"])) self.assertEqual(0, len(channel.json_body["users"]))
@ -415,7 +415,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3) self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media # filter media starting at `ts1` after creating first media
@ -425,7 +425,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,), self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0) self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3) self._create_media(self.other_user_tok, 3)
@ -440,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3) self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier # filter media until `ts2` and earlier
@ -449,7 +449,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,), self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6) self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self) -> None: def test_search_term(self) -> None:
@ -461,7 +461,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id` # filter user 1 and 10-19 by `user_id`
@ -470,7 +470,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1", self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 11) self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname` # filter on this user in `displayname`
@ -479,7 +479,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10", self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1) self.assertEqual(channel.json_body["total"], 1)
@ -489,7 +489,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar", self.url + "?search_term=foobar",
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0) self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
@ -515,7 +515,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media): for _ in range(number_media):
# Upload some media into the room # Upload some media into the room
self.helper.upload_media( self.helper.upload_media(
upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK upload_resource, SMALL_PNG, tok=user_token, expect_code=200
) )
def _check_fields(self, content: List[JsonDict]) -> None: def _check_fields(self, content: List[JsonDict]) -> None:
@ -549,7 +549,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"), url.encode("ascii"),
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list)) self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]] returned_order = [row["user_id"] for row in channel.json_body["users"]]

View file

@ -169,7 +169,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self) -> None: def test_nonce_reuse(self) -> None:
@ -192,7 +192,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it # Now, try and reuse it
@ -323,11 +323,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"]) self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname") channel = self.make_request("GET", "/profile/@bob1:test/displayname")
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"]) self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None # displayname is None
@ -347,11 +347,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"]) self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname") channel = self.make_request("GET", "/profile/@bob2:test/displayname")
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"]) self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty # displayname is empty
@ -371,7 +371,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"]) self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname") channel = self.make_request("GET", "/profile/@bob3:test/displayname")
@ -394,11 +394,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"]) self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname") channel = self.make_request("GET", "/profile/@bob4:test/displayname")
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"]) self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config( @override_config(
@ -442,7 +442,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
@ -494,7 +494,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"]) self.assertEqual(3, channel.json_body["total"])
@ -508,7 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str], expected_user_id: Optional[str],
search_term: str, search_term: str,
search_field: Optional[str] = "name", search_field: Optional[str] = "name",
expected_http_code: Optional[int] = HTTPStatus.OK, expected_http_code: Optional[int] = 200,
) -> None: ) -> None:
"""Search for a user and check that the returned user's id is a match """Search for a user and check that the returned user's id is a match
@ -530,7 +530,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
if expected_http_code != HTTPStatus.OK: if expected_http_code != 200:
return return
# Check that users were returned # Check that users were returned
@ -659,7 +659,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5") self.assertEqual(channel.json_body["next_token"], "5")
@ -680,7 +680,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15) self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -701,7 +701,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10) self.assertEqual(len(channel.json_body["users"]), 10)
@ -724,7 +724,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users) self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -737,7 +737,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users) self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -750,7 +750,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19") self.assertEqual(channel.json_body["next_token"], "19")
@ -764,7 +764,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1) self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -867,7 +867,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list)) self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]] returned_order = [row["name"] for row in channel.json_body["users"]]
@ -1017,7 +1017,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@ -1032,7 +1032,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True}, content={"erase": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1041,7 +1041,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
@ -1066,7 +1066,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"erase": True}, content={"erase": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_erased("@user:test", True) self._is_erased("@user:test", True)
def test_deactivate_user_erase_false(self) -> None: def test_deactivate_user_erase_false(self) -> None:
@ -1081,7 +1081,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@ -1096,7 +1096,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False}, content={"erase": False},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1105,7 +1105,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
@ -1135,7 +1135,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@ -1150,7 +1150,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True}, content={"erase": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1159,7 +1159,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
@ -1352,7 +1352,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
@ -1395,7 +1395,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1458,7 +1458,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1486,7 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached # before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
if channel.code != HTTPStatus.OK: if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"] channel.code, channel.result["reason"], channel.result["body"]
) )
@ -1684,7 +1684,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"}, content={"password": "hahaha"},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
def test_set_displayname(self) -> None: def test_set_displayname(self) -> None:
@ -1700,7 +1700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"}, content={"displayname": "foobar"},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"]) self.assertEqual("foobar", channel.json_body["displayname"])
@ -1711,7 +1711,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"]) self.assertEqual("foobar", channel.json_body["displayname"])
@ -1733,7 +1733,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted # result does not always have the same sort order, therefore it becomes sorted
@ -1759,7 +1759,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1775,7 +1775,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1791,7 +1791,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"threepids": []}, content={"threepids": []},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
@ -1818,7 +1818,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1837,7 +1837,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1859,7 +1859,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
) )
# other user has this two threepids # other user has this two threepids
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted # result does not always have the same sort order, therefore it becomes sorted
@ -1878,7 +1878,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user, url_first_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
@ -1907,7 +1907,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted # result does not always have the same sort order, therefore it becomes sorted
@ -1939,7 +1939,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -1958,7 +1958,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -1977,7 +1977,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"external_ids": []}, content={"external_ids": []},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"])) self.assertEqual(0, len(channel.json_body["external_ids"]))
@ -2006,7 +2006,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -2032,7 +2032,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
}, },
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -2075,7 +2075,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -2093,7 +2093,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual( self.assertEqual(
@ -2124,7 +2124,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@ -2139,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True}, content={"deactivated": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
@ -2158,7 +2158,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
@ -2188,7 +2188,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True}, content={"deactivated": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
@ -2204,7 +2204,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"}, content={"displayname": "Foobar"},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"]) self.assertEqual("Foobar", channel.json_body["displayname"])
@ -2237,7 +2237,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"}, content={"deactivated": False, "password": "foo"},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False) self._is_erased("@user:test", False)
@ -2271,7 +2271,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"deactivated": False}, content={"deactivated": False},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False) self._is_erased("@user:test", False)
@ -2305,7 +2305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"deactivated": False}, content={"deactivated": False},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False) self._is_erased("@user:test", False)
@ -2326,7 +2326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True}, content={"admin": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
@ -2337,7 +2337,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
@ -2354,7 +2354,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT}, content={"user_type": UserTypes.SUPPORT},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@ -2365,7 +2365,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@ -2377,7 +2377,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None}, content={"user_type": None},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"]) self.assertIsNone(channel.json_body["user_type"])
@ -2388,7 +2388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"]) self.assertIsNone(channel.json_body["user_type"])
@ -2418,7 +2418,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
@ -2440,7 +2440,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual("bob", channel.json_body["displayname"])
@ -2465,7 +2465,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"deactivated": True}, content={"deactivated": True},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self._is_erased(user_id, False) self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id) d = self.store.mark_user_erased(user_id)
@ -2549,7 +2549,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"])) self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@ -2565,7 +2565,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"])) self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@ -2581,7 +2581,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"])) self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@ -2602,7 +2602,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@ -2649,7 +2649,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
@ -2737,7 +2737,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
# Register the pusher # Register the pusher
@ -2769,7 +2769,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"]) self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]: for p in channel.json_body["pushers"]:
@ -2865,7 +2865,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
@ -2884,7 +2884,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5) self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@ -2901,7 +2901,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15) self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -2920,7 +2920,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15) self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@ -2937,7 +2937,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10) self.assertEqual(len(channel.json_body["media"]), 10)
@ -2956,7 +2956,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10) self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@ -3023,7 +3023,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media) self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -3036,7 +3036,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media) self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -3049,7 +3049,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19) self.assertEqual(channel.json_body["next_token"], 19)
@ -3063,7 +3063,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1) self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -3080,7 +3080,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"])) self.assertEqual(0, len(channel.json_body["media"]))
@ -3095,7 +3095,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"])) self.assertEqual(0, len(channel.json_body["deleted_media"]))
@ -3112,7 +3112,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"])) self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
@ -3138,7 +3138,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids) self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@ -3283,7 +3283,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK upload_resource, image_data, user_token, filename, expect_code=200
) )
# Extract media ID from the response # Extract media ID from the response
@ -3301,10 +3301,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual( self.assertEqual(
HTTPStatus.OK, 200,
channel.code, channel.code,
msg=( msg=(
f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" f"Expected to receive a 200 on accessing media: {server_and_media_id}"
), ),
) )
@ -3350,7 +3350,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list)) self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]] returned_order = [row["media_id"] for row in channel.json_body["media"]]
@ -3386,7 +3386,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok "POST", self.url, b"{}", access_token=self.admin_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"] return channel.json_body["access_token"]
def test_no_auth(self) -> None: def test_no_auth(self) -> None:
@ -3427,7 +3427,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`) # We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1) self.assertEqual(len(channel.json_body["devices"]), 1)
@ -3439,11 +3439,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token # Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work # The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@ -3453,7 +3453,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self) -> None: def test_user_logout_all(self) -> None:
"""Tests that the target user calling `/logout/all` does *not* expire """Tests that the target user calling `/logout/all` does *not* expire
@ -3464,17 +3464,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token # Logout all with the real user token
channel = self.make_request( channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok "POST", "logout/all", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work # The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't # .. but the real user's tokens shouldn't
channel = self.make_request( channel = self.make_request(
@ -3491,13 +3491,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request # Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token # Logout all with the admin user token
channel = self.make_request( channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok "POST", "logout/all", b"{}", access_token=self.admin_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work # The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@ -3507,7 +3507,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok "GET", "devices", b"{}", access_token=self.other_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config( @unittest.override_config(
{ {
@ -3635,7 +3635,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body) self.assertIn("devices", channel.json_body)
@ -3650,7 +3650,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=other_user_token, access_token=other_user_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body) self.assertIn("devices", channel.json_body)
@ -3715,7 +3715,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned) self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body) self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared). # Ensure the user is shadow-banned (and the cache was cleared).
@ -3727,7 +3727,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok "DELETE", self.url, access_token=self.admin_user_tok
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body) self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared). # Ensure the user is no longer shadow-banned (and the cache was cleared).
@ -3891,7 +3891,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"]) self.assertEqual(0, channel.json_body["burst_count"])
@ -3905,7 +3905,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)
@ -3916,7 +3916,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11}, content={"messages_per_second": 10, "burst_count": 11},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"]) self.assertEqual(11, channel.json_body["burst_count"])
@ -3927,7 +3927,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21}, content={"messages_per_second": 20, "burst_count": 21},
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"]) self.assertEqual(21, channel.json_body["burst_count"])
@ -3937,7 +3937,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"]) self.assertEqual(21, channel.json_body["burst_count"])
@ -3947,7 +3947,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)
@ -3957,7 +3957,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body) self.assertNotIn("burst_count", channel.json_body)
@ -4042,7 +4042,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
self.url, self.url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual( self.assertEqual(
{"a": 1}, channel.json_body["account_data"]["global"]["m.global"] {"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
) )

View file

@ -50,18 +50,18 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
def test_username_available(self) -> None: def test_username_available(self) -> None:
""" """
The endpoint should return a HTTPStatus.OK response if the username does not exist The endpoint should return a 200 response if the username does not exist
""" """
url = "%s?username=%s" % (self.url, "allowed") url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok) channel = self.make_request("GET", url, access_token=self.admin_user_tok)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"]) self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self) -> None: def test_username_unavailable(self) -> None:
""" """
The endpoint should return a HTTPStatus.OK response if the username does not exist The endpoint should return a 200 response if the username does not exist
""" """
url = "%s?username=%s" % (self.url, "disallowed") url = "%s?username=%s" % (self.url, "disallowed")

View file

@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"}) self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success( filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0) self.store.get_user_filter(user_localpart="apple", filter_id=0)
@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None: def test_add_filter_non_local_user(self) -> None:
@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
) )
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None: def test_get_filter(self) -> None:
@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None: def test_get_filter_non_existant(self) -> None:
@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"404") self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None: def test_get_filter_no_id(self) -> None:
@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.code, 400)

View file

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import time import time
import urllib.parse import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from urllib.parse import urlencode from urllib.parse import urlencode
@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config( @override_config(
{ {
@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config( @override_config(
{ {
@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None: def test_soft_logout(self) -> None:
@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token # we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL) channel = self.make_request(b"GET", TEST_URL)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal # log in as normal
@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"] access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"] device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout # more requests with the expired token should still return a soft-logout
self.reactor.advance(3600) self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id) self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False) self.assertEqual(channel.json_body["soft_logout"], False)
@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token b"DELETE", "devices/" + device_id, access_token=access_token
) )
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
# check it's a UI-Auth fail # check it's a UI-Auth fail
self.assertEqual( self.assertEqual(
set(channel.json_body.keys()), set(channel.json_body.keys()),
@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
content={"auth": auth}, content={"auth": auth},
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard logout this session # Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token) channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions # Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token) channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None: def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese") self.register_user("mickey", "cheese")
@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None: def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows""" """GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [ expected_flow_types = [
"m.login.cas", "m.login.cas",
@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None) channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
uri = location_headers[0] uri = location_headers[0]
# hitting that picker should give us some HTML # hitting that picker should give us some HTML
channel = self.make_request("GET", uri) channel = self.make_request("GET", uri)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class # parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8") html = channel.result["body"].decode("utf-8")
@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas", + "&idp=cas",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
cas_uri = location_headers[0] cas_uri = location_headers[0]
@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml", + "&idp=saml",
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
saml_uri = location_headers[0] saml_uri = location_headers[0]
@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc", + "&idp=oidc",
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page # that should serve a confirmation page
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type") content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html")) self.assertTrue(content_type_headers[-1].startswith("text/html"))
@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test") self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None: def test_multi_sso_redirect_to_unknown(self) -> None:
@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
) )
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx") channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc") channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url) channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML. # Test that the response is HTML.
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = "" content_type_header_value = ""
for header in channel.result.get("headers", []): for header in channel.result.get("headers", []):
if header[0] == b"Content-Type": if header[0] == b"Content-Type":
@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None: def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None: def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"}) channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret") channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None: def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000}) channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self) -> None: def test_login_jwt_not_before(self) -> None:
now = int(time.time()) now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self) -> None: def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"}) channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT") self.assertEqual(channel.json_body["error"], "Invalid JWT")
@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim.""" """Test validating the issuer claim."""
# A valid issuer. # A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer. # An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an issuer. # Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_iss_no_config(self) -> None: def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration.""" """Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim.""" """Test validating the audience claim."""
# A valid audience. # A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience. # An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an audience. # Not providing an audience.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self) -> None: def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration.""" """Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_default_sub(self) -> None: def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim.""" """Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None: def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim.""" """Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"}) channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None: def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"} params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None: def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None: def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login""" """Test that the appservice bot can use /login"""
@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None: def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token""" """Test that non-as users cannot login with the as token"""
@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None: def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token""" """Test that as users cannot login with wrong as token"""
@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token b"POST", LOGIN_URL, params, access_token=self.another_service.token
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None: def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice """Test that users must provide a token when using the appservice
@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
# that should redirect to the username picker # that should redirect to the username picker
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
picker_url = location_headers[0] picker_url = location_headers[0]
@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))), ("Content-Length", str(len(content))),
], ],
) )
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0], path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)], custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
) )
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test") self.assertEqual(chan.json_body["user_id"], "@bobby:test")

View file

@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase):
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
channel = self.make_request("POST", path, content={}, access_token=access_token) channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code) self.assertEqual(channel.code, expect_code)
return channel.json_body return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token) channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id] room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"] return room_sync["timeline"]["events"]

View file

@ -70,7 +70,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname} det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@ -91,7 +91,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None: def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists self.appservice = None # no application service exists
@ -100,20 +100,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None: def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666} request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password") self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None: def test_POST_bad_username(self) -> None:
request_data = {"username": 777, "password": "monkey"} request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username") self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None: def test_POST_user_valid(self) -> None:
@ -132,7 +132,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False}) @override_config({"enable_registration": False})
@ -142,7 +142,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -153,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None: def test_POST_disabled_guest_registration(self) -> None:
@ -161,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled") self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@ -171,16 +171,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}") channel = self.make_request(b"POST", url, b"{}")
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None: def test_POST_ratelimiting(self) -> None:
@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None: def test_POST_registration_requires_token(self) -> None:
@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Request without auth to get flows and session # Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one # Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow. # flow would be a subset of another flow.
@ -248,7 +248,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"] completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@ -263,7 +263,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0 # Check the `completed` counter has been incremented and pending is 0
@ -293,21 +293,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid) # Test with non-string (invalid)
params["auth"]["token"] = 1234 params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid) # Test with unknown token (invalid)
params["auth"]["token"] = "1234" params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -361,7 +361,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session2, "session": session2,
} }
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -381,7 +381,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Check auth still fails when using token with session2 # Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -415,7 +415,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -570,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None: def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow # with the stock config, we only expect the dummy flow
@ -593,7 +593,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
self.assertCountEqual( self.assertCountEqual(
@ -625,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
def test_advertised_flows_no_msisdn_email_required(self) -> None: def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid # with the stock config, we expect all four combinations of 3pid
@ -797,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
@ -823,12 +823,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity" url = "/_synapse/admin/v1/account_validity/validity"
request_data = {"user_id": user_id} request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated # The specific endpoint doesn't matter, all we need is an authenticated
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None: def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
@ -844,12 +844,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False, "enable_renewal_emails": False,
} }
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated # The specific endpoint doesn't matter, all we need is an authenticated
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
@ -868,18 +868,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False, "enable_renewal_emails": False,
} }
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out # Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok) channel = self.make_request(b"POST", "/logout", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts) # Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions # Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok) channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@ -954,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -972,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again. # Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -992,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed. # succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds()) self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None: def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as # Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123" url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"404", channel.result) self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -1023,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail", "/_matrix/client/unstable/account_validity/send_mail",
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -1096,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail", "/_matrix/client/unstable/account_validity/send_mail",
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -1176,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True) self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None: def test_GET_token_invalid(self) -> None:
@ -1185,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False) self.assertEqual(channel.json_body["valid"], False)
@override_config( @override_config(
@ -1201,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
) )
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@ -1212,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)

View file

@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
participated, bundled_aggregations.get("current_user_participated") participated, bundled_aggregations.get("current_user_participated")
) )
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict( self.assert_dict(
{ {
"content": { "content": {
@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user2_id, "sender": self.user2_id,
"type": "m.room.test", "type": "m.room.test",
}, },
bundled_aggregations.get("latest_event"), bundled_aggregations["latest_event"],
) )
return assert_thread return assert_thread
@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(2, bundled_aggregations.get("count")) self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated")) self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict( self.assert_dict(
{ {
"content": { "content": {
@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user_id, "sender": self.user_id,
"type": "m.room.test", "type": "m.room.test",
}, },
bundled_aggregations.get("latest_event"), bundled_aggregations["latest_event"],
) )
# Check the unsigned field on the latest event. # Check the unsigned field on the latest event.
self.assert_dict( self.assert_dict(

View file

@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok "POST", self.report_path, data, access_token=self.other_user_tok
) )
self.assertEqual( self.assertEqual(response_status, channel.code, msg=channel.result["body"])
response_status, int(channel.result["code"]), msg=channel.result["body"]
)

View file

@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual( self.assertCountEqual(
[state_event["type"] for state_event in channel.json_body], [state_event["type"] for state_event in channel.json_list],
{ {
"m.room.create", "m.room.create",
"m.room.power_levels", "m.room.power_levels",
@ -2070,7 +2070,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
config = self.default_config() config = self.default_config()
config["allow_public_rooms_without_auth"] = True config["allow_public_rooms_without_auth"] = True
config["experimental_features"] = {"msc3827_enabled": True}
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
self.url = b"/_matrix/client/r0/publicRooms" self.url = b"/_matrix/client/r0/publicRooms"
@ -2123,13 +2122,13 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
chunk, count = self.make_public_rooms_request([None]) chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1) self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None) self.assertEqual(chunk[0].get("room_type", None), None)
def test_returns_only_space_based_on_filter(self) -> None: def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"]) chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1) self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space") self.assertEqual(chunk[0].get("room_type", None), "m.space")
def test_returns_both_rooms_and_space_based_on_filter(self) -> None: def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None]) chunk, count = self.make_public_rooms_request(["m.space", None])

View file

@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin, KnockingStrippedStateEventHelperMixin,
) )
from tests.server import TimedOutException from tests.server import TimedOutException
from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase): class FilterTestCase(unittest.HomeserverTestCase):
@ -390,6 +389,12 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["experimental_features"] = {"msc2285_enabled": True}
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s" self.url = "/sync?since=%s"
self.next_batch = "s0" self.next_batch = "s0"
@ -408,15 +413,17 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user # Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_private_read_receipts(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_private_read_receipts(self, receipt_type: str) -> None:
# Send a message as the first user # Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok) res = self.helper.send(self.room_id, body="hello", tok=self.tok)
# Send a private read receipt to tell the server the first user's message was read # Send a private read receipt to tell the server the first user's message was read
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -425,8 +432,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt # Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt()) self.assertIsNone(self._get_read_receipt())
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_public_receipt_can_override_private(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_public_receipt_can_override_private(self, receipt_type: str) -> None:
""" """
Sending a public read receipt to the same event which has a private read Sending a public read receipt to the same event which has a private read
receipt should cause that receipt to become public. receipt should cause that receipt to become public.
@ -437,7 +446,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt # Send a private read receipt
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -456,8 +465,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt # Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None) self.assertNotEqual(self._get_read_receipt(), None)
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_private_receipt_cannot_override_public(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None:
""" """
Sending a private read receipt to the same event which has a public read Sending a private read receipt to the same event which has a public read
receipt should cause no change. receipt should cause no change.
@ -478,7 +489,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt # Send a private read receipt
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -590,7 +601,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
tok=self.tok, tok=self.tok,
) )
def test_unread_counts(self) -> None: @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_unread_counts(self, receipt_type: str) -> None:
"""Tests that /sync returns the right value for the unread count (MSC2654).""" """Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count. # Check that our own messages don't increase the unread count.
@ -624,7 +638,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event. # Send a read receipt to tell the server we've read the latest event.
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )
@ -700,7 +714,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5) self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
# Make sure both m.read and org.matrix.msc2285.read.private advance # Make sure both m.read and m.read.private advance
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@ -712,16 +726,22 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0) self._check_unread_count(0)
# We test for both receipt types that influence notification counts # We test for all three receipt types that influence notification counts
@parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) @parameterized.expand(
def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
]
)
def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user # Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@ -739,11 +759,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0) self._check_unread_count(0)
# Make sure neither m.read nor org.matrix.msc2285.read.private make the # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event # read receipt go up to an older event
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )
@ -948,3 +968,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
def test_incremental_sync(self) -> None:
"""Tests that activity in the room is properly filtered out of incremental
syncs.
"""
channel = self.make_request("GET", "/sync", access_token=self.tok)
self.assertEqual(channel.code, 200, channel.result)
next_batch = channel.json_body["next_batch"]
self.helper.send(self.excluded_room_id, tok=self.tok)
self.helper.send(self.included_room_id, tok=self.tok)
channel = self.make_request(
"GET",
f"/sync?since={next_batch}",
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])

View file

@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin from synapse.rest import admin
@ -154,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once() callback.assert_called_once()
@ -172,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
""" """
@ -185,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
""" """
class NastyHackException(SynapseError): class NastyHackException(SynapseError):
def error_dict(self) -> JsonDict: def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict:
""" """
This overrides SynapseError's `error_dict` to nastily inject This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response. JSON into the error response.
""" """
result = super().error_dict() result = super().error_dict(config)
result["nasty"] = "very" result["nasty"] = "very"
return result return result
@ -210,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok, access_token=self.tok,
) )
# Check the error code # Check the error code
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected # Check the JSON body has had the `nasty` key injected
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
@ -259,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"}, {"x": "x"},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"] event_id = channel.json_body["event_id"]
# ... and check that it got modified # ... and check that it got modified
@ -268,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y") self.assertEqual(ev["content"]["x"], "y")
@ -297,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"] orig_event_id = channel.json_body["event_id"]
channel = self.make_request( channel = self.make_request(
@ -314,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"] edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified # ... and check that they both got modified
@ -323,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@ -332,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY") self.assertEqual(ev["content"]["body"], "EDITED BODY")
@ -378,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"] event_id = channel.json_body["event_id"]
@ -387,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys()) self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar") self.assertEqual(channel.json_body["content"]["foo"], "bar")

View file

@ -140,7 +140,7 @@ class RestHelper:
custom_headers=custom_headers, custom_headers=custom_headers,
) )
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK: if expect_code == HTTPStatus.OK:
@ -213,11 +213,9 @@ class RestHelper:
data, data,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -312,11 +310,9 @@ class RestHelper:
data, data,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -396,11 +392,9 @@ class RestHelper:
custom_headers=custom_headers, custom_headers=custom_headers,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -449,11 +443,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content) channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -545,7 +537,7 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )

View file

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
from tests import unittest from tests import unittest
@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase):
def test_health(self) -> None: def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False) channel = self.make_request("GET", "/health", shorthand=False)
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK") self.assertEqual(channel.result["body"], b"OK")

View file

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource from synapse.rest.well_known import well_known_resource
@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -57,7 +55,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) self.assertEqual(channel.code, 404)
@unittest.override_config( @unittest.override_config(
{ {
@ -71,7 +69,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -87,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{"m.server": "test:443"}, {"m.server": "test:443"},
@ -97,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) self.assertEqual(channel.code, 404)

View file

@ -25,6 +25,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List,
MutableMapping, MutableMapping,
Optional, Optional,
Tuple, Tuple,
@ -121,7 +122,15 @@ class FakeChannel:
@property @property
def json_body(self) -> JsonDict: def json_body(self) -> JsonDict:
return json.loads(self.text_body) body = json.loads(self.text_body)
assert isinstance(body, dict)
return body
@property
def json_list(self) -> List[JsonDict]:
body = json.loads(self.text_body)
assert isinstance(body, list)
return body
@property @property
def text_body(self) -> str: def text_body(self) -> str:

View file

@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None): def persist_event(self, event, state=None):
"""Persist the event, with optional state""" """Persist the event, with optional state"""
context = self.get_success( context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state) self.state.compute_event_context(
event,
state_ids_before_event=state,
partial_state=None if state is None else False,
)
) )
self.get_success(self._persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.state.compute_event_context( self.state.compute_event_context(
remote_event_2, remote_event_2,
state_ids_before_event=state_before_gap, state_ids_before_event=state_before_gap,
partial_state=False,
) )
) )

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from parameterized import parameterized
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase): class ReceiptTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase):
) )
) )
def test_return_empty_with_no_data(self): def test_return_empty_with_no_data(self) -> None:
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, {}) self.assertEqual(res, {})
@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase):
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user_with_orderings( self.store.get_receipts_for_user_with_orderings(
OUR_USER_ID, OUR_USER_ID,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, {}) self.assertEqual(res, {})
@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase):
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id1, self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, None) self.assertEqual(res, None)
def test_get_receipts_for_user(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_get_receipts_for_user(self, receipt_type: str) -> None:
# Send some events into the first room # Send some events into the first room
event1_1_id = self.create_and_send_event( event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID) self.room_id1, UserID.from_string(OTHER_USER_ID)
@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
) )
) )
# Test we get the latest event when we want both private and public receipts # Test we get the latest event when we want both private and public receipts
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
) )
) )
self.assertEqual(res, {self.room_id1: event1_2_id}) self.assertEqual(res, {self.room_id1: event1_2_id})
@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the public receipt # Test we get the latest event when we want only the public receipt
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type])
) )
self.assertEqual(res, {self.room_id1: event1_2_id}) self.assertEqual(res, {self.room_id1: event1_2_id})
@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
) )
) )
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
) )
) )
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
def test_get_last_receipt_event_id_for_user(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None:
# Send some events into the first room # Send some events into the first room
event1_1_id = self.create_and_send_event( event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID) self.room_id1, UserID.from_string(OTHER_USER_ID)
@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
) )
) )
@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase):
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id1, self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [ReceiptTypes.READ, receipt_type],
) )
) )
self.assertEqual(res, event1_2_id) self.assertEqual(res, event1_2_id)
@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the private receipt # Test we get the latest event when we want only the private receipt
res = self.get_success( res = self.get_success(
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] OUR_USER_ID, self.room_id1, [receipt_type]
) )
) )
self.assertEqual(res, event1_2_id) self.assertEqual(res, event1_2_id)
@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
) )
) )
res = self.get_success( res = self.get_success(
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id2, self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [ReceiptTypes.READ, receipt_type],
) )
) )
self.assertEqual(res, event2_1_id) self.assertEqual(res, event2_1_id)

Some files were not shown because too many files have changed in this diff Show more