diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 6c3a99849..0dfab4e08 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,6 +8,7 @@
- Use markdown where necessary, mostly for `code blocks`.
- End with either a period (.) or an exclamation mark (!).
- Start with a capital letter.
+ - Feel free to credit yourself by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry.
* [ ] Pull request includes a [sign off](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#sign-off)
* [ ] [Code style](https://matrix-org.github.io/synapse/latest/code_style.html) is correct
(run the [linters](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index cb72e1a23..4f5806970 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -366,6 +366,8 @@ jobs:
# Build initial Synapse image
- run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
working-directory: synapse
+ env:
+ DOCKER_BUILDKIT: 1
# Build a ready-to-run Synapse image based on the initial image above.
# This new image includes a config file, keys for signing and TLS, and
@@ -374,7 +376,8 @@ jobs:
working-directory: complement/dockerfiles
# Run Complement
- - run: go test -v -tags synapse_blacklist,msc2403 ./tests/...
+ - run: set -o pipefail && go test -v -json -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt
+ shell: bash
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement
diff --git a/.gitignore b/.gitignore
index fe137f337..3bd6b1a08 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,7 @@ __pycache__/
# docs
book/
+
+# complement
+/complement-*
+/master.tar.gz
diff --git a/CHANGES.md b/CHANGES.md
index b8f3588ec..6da56a11f 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,10 +1,101 @@
+Synapse 1.51.0 (2022-01-25)
+===========================
+
+No significant changes since 1.51.0rc2.
+
+Synapse 1.51.0 deprecates `webclient` listeners and non-HTTP(S) `web_client_location`s. Support for these will be removed in Synapse 1.53.0, at which point Synapse will not be capable of directly serving a web client for Matrix.
+
+Synapse 1.51.0rc2 (2022-01-24)
+==============================
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806))
+
+
+Synapse 1.51.0rc1 (2022-01-21)
+==============================
+
+Features
+--------
+
+- Add `track_puppeted_user_ips` config flag to record client IP addresses against puppeted users, and include the puppeted users in monthly active user counts. ([\#11561](https://github.com/matrix-org/synapse/issues/11561), [\#11749](https://github.com/matrix-org/synapse/issues/11749), [\#11757](https://github.com/matrix-org/synapse/issues/11757))
+- Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11577](https://github.com/matrix-org/synapse/issues/11577))
+- Return an `M_FORBIDDEN` error code instead of `M_UNKNOWN` when a spam checker module prevents a user from creating a room. ([\#11672](https://github.com/matrix-org/synapse/issues/11672))
+- Add a flag to the `synapse_review_recent_signups` script to ignore and filter appservice users. ([\#11675](https://github.com/matrix-org/synapse/issues/11675), [\#11770](https://github.com/matrix-org/synapse/issues/11770))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing issue which could cause Synapse to incorrectly accept data in the unsigned field of events
+ received over federation. ([\#11530](https://github.com/matrix-org/synapse/issues/11530))
+- Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices. ([\#11587](https://github.com/matrix-org/synapse/issues/11587))
+- Fix an error that occurs whilst trying to get the federation status of a destination server that was working normally. This admin API was newly introduced in Synapse v1.49.0. ([\#11593](https://github.com/matrix-org/synapse/issues/11593))
+- Fix bundled aggregations not being included in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11612](https://github.com/matrix-org/synapse/issues/11612), [\#11659](https://github.com/matrix-org/synapse/issues/11659), [\#11791](https://github.com/matrix-org/synapse/issues/11791))
+- Fix the `/_matrix/client/v1/room/{roomId}/hierarchy` endpoint returning incorrect fields, a bug which has been present since Synapse 1.49.0. ([\#11667](https://github.com/matrix-org/synapse/issues/11667))
+- Fix preview of some GIF URLs (like tenor.com). Contributed by Philippe Daouadi. ([\#11669](https://github.com/matrix-org/synapse/issues/11669))
+- Fix a bug where only the first 50 rooms from a space were returned from the `/hierarchy` API. This has existed since the introduction of the API in Synapse v1.41.0. ([\#11695](https://github.com/matrix-org/synapse/issues/11695))
+- Fix a bug introduced in Synapse v1.18.0 where password reset and address validation emails would not be sent if their subject was configured to use the 'app' template variable. Contributed by @br4nnigan. ([\#11710](https://github.com/matrix-org/synapse/issues/11710), [\#11745](https://github.com/matrix-org/synapse/issues/11745))
+- Make the 'List Rooms' Admin API sort stable. Contributed by Daniël Sonck. ([\#11737](https://github.com/matrix-org/synapse/issues/11737))
+- Fix a long-standing bug where space hierarchy over federation would only work correctly some of the time. ([\#11775](https://github.com/matrix-org/synapse/issues/11775))
+- Fix a bug introduced in Synapse v1.46.0 that prevented `on_logged_out` module callbacks from being correctly awaited by Synapse. ([\#11786](https://github.com/matrix-org/synapse/issues/11786))
+
+
+Improved Documentation
+----------------------
+
+- Warn against using a Let's Encrypt certificate for TLS/DTLS TURN server client connections, and suggest using a ZeroSSL certificate instead. This works around client-side connectivity errors caused by WebRTC libraries that reject Let's Encrypt certificates. Contributed by @AndrewFerr. ([\#11686](https://github.com/matrix-org/synapse/issues/11686))
+- Document the new `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable in the contributing guide. ([\#11715](https://github.com/matrix-org/synapse/issues/11715))
+- Document that the minimum supported PostgreSQL version is now 10. ([\#11725](https://github.com/matrix-org/synapse/issues/11725))
+- Fix typo in demo docs: differnt. ([\#11735](https://github.com/matrix-org/synapse/issues/11735))
+- Update room spec URL in config files. ([\#11739](https://github.com/matrix-org/synapse/issues/11739))
+- Mention `python3-venv` and `libpq-dev` dependencies in the contribution guide. ([\#11740](https://github.com/matrix-org/synapse/issues/11740))
+- Update documentation for configuring login with Facebook. ([\#11755](https://github.com/matrix-org/synapse/issues/11755))
+- Update installation instructions to note that Python 3.6 is no longer supported. ([\#11781](https://github.com/matrix-org/synapse/issues/11781))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove the unstable `/send_relation` endpoint. ([\#11682](https://github.com/matrix-org/synapse/issues/11682))
+- Remove `python_twisted_reactor_pending_calls` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724))
+- Remove the `password_hash` field from the response dictionaries of the [Users Admin API](https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html). ([\#11576](https://github.com/matrix-org/synapse/issues/11576))
+- **Deprecate support for `webclient` listeners and non-HTTP(S) `web_client_location` configuration. ([\#11774](https://github.com/matrix-org/synapse/issues/11774), [\#11783](https://github.com/matrix-org/synapse/issues/11783))**
+
+
+Internal Changes
+----------------
+
+- Run `pyupgrade --py37-plus --keep-percent-format` on Synapse. ([\#11685](https://github.com/matrix-org/synapse/issues/11685))
+- Use buildkit's cache feature to speed up docker builds. ([\#11691](https://github.com/matrix-org/synapse/issues/11691))
+- Use `auto_attribs` and native type hints for attrs classes. ([\#11692](https://github.com/matrix-org/synapse/issues/11692), [\#11768](https://github.com/matrix-org/synapse/issues/11768))
+- Remove debug logging for #4422, which has been closed since Synapse 0.99. ([\#11693](https://github.com/matrix-org/synapse/issues/11693))
+- Remove fallback code for Python 2. ([\#11699](https://github.com/matrix-org/synapse/issues/11699))
+- Add a test for [an edge case](https://github.com/matrix-org/synapse/pull/11532#discussion_r769104461) in the `/sync` logic. ([\#11701](https://github.com/matrix-org/synapse/issues/11701))
+- Add the option to write SQLite test dbs to disk when running tests. ([\#11702](https://github.com/matrix-org/synapse/issues/11702))
+- Improve Complement test output for GitHub Actions. ([\#11707](https://github.com/matrix-org/synapse/issues/11707))
+- Fix docstring on `add_account_data_for_user`. ([\#11716](https://github.com/matrix-org/synapse/issues/11716))
+- Rename a Complement environment variable and update `.gitignore` to match. ([\#11718](https://github.com/matrix-org/synapse/issues/11718))
+- Simplify calculation of Prometheus metrics for garbage collection. ([\#11723](https://github.com/matrix-org/synapse/issues/11723))
+- Improve accuracy of `python_twisted_reactor_tick_time` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724), [\#11771](https://github.com/matrix-org/synapse/issues/11771))
+- Minor efficiency improvements when inserting many values into the database. ([\#11742](https://github.com/matrix-org/synapse/issues/11742))
+- Invite PR authors to give themselves credit in the changelog. ([\#11744](https://github.com/matrix-org/synapse/issues/11744))
+- Add optional debugging to investigate [issue 8631](https://github.com/matrix-org/synapse/issues/8631). ([\#11760](https://github.com/matrix-org/synapse/issues/11760))
+- Remove `log_function` utility function and its uses. ([\#11761](https://github.com/matrix-org/synapse/issues/11761))
+- Add a unit test that checks both `client` and `webclient` resources will function when simultaneously enabled. ([\#11765](https://github.com/matrix-org/synapse/issues/11765))
+- Allow overriding complement commit using `COMPLEMENT_REF`. ([\#11766](https://github.com/matrix-org/synapse/issues/11766))
+- Add some comments and type annotations for `_update_outliers_txn`. ([\#11776](https://github.com/matrix-org/synapse/issues/11776))
+
+
Synapse 1.50.2 (2022-01-24)
===========================
Bugfixes
--------
-- Fix a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806))
+- Backport the sole fix from v1.51.0rc2. This fixes a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806))
Synapse 1.50.1 (2022-01-18)
diff --git a/contrib/prometheus/consoles/synapse.html b/contrib/prometheus/consoles/synapse.html
index cd9ad1523..d17c8a08d 100644
--- a/contrib/prometheus/consoles/synapse.html
+++ b/contrib/prometheus/consoles/synapse.html
@@ -92,22 +92,6 @@ new PromConsole.Graph({
})
</script>

-<h3>Pending calls per tick</h3>
-<div id="reactor_pending_calls"></div>
-<script>
-new PromConsole.Graph({
-  node: document.querySelector("#reactor_pending_calls"),
-  expr: "rate(python_twisted_reactor_pending_calls_sum[30s]) / rate(python_twisted_reactor_pending_calls_count[30s])",
-  name: "calls",
-  min: 0,
-  renderer: "line",
-  height: 150,
-  yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
-  yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix,
-  yTitle: "Pending calls"
-})
-</script>
-
<h1>Storage</h1>

<h3>Queries</h3>
diff --git a/debian/changelog b/debian/changelog
index c790c2187..3a598c414 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,21 @@
+matrix-synapse-py3 (1.51.0) stable; urgency=medium
+
+ * New synapse release 1.51.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 25 Jan 2022 11:28:51 +0000
+
+matrix-synapse-py3 (1.51.0~rc2) stable; urgency=medium
+
+ * New synapse release 1.51.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Mon, 24 Jan 2022 12:25:00 +0000
+
+matrix-synapse-py3 (1.51.0~rc1) stable; urgency=medium
+
+ * New synapse release 1.51.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 21 Jan 2022 10:46:02 +0000
+
matrix-synapse-py3 (1.50.2) stable; urgency=medium
* New synapse release 1.50.2.
diff --git a/demo/README b/demo/README
index 0bec820ad..a5a95bd19 100644
--- a/demo/README
+++ b/demo/README
@@ -22,5 +22,5 @@ Logs and sqlitedb will be stored in demo/808{0,1,2}.{log,db}
-Also note that when joining a public room on a differnt HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name.
+Also note that when joining a public room on a different HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 2bdc607e6..306f75ae5 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,14 +1,17 @@
# Dockerfile to build the matrixdotorg/synapse docker images.
#
+# Note that it uses features which are only available in BuildKit - see
+# https://docs.docker.com/go/buildkit/ for more information.
+#
# To build the image, run `docker build` command from the root of the
# synapse repository:
#
-# docker build -f docker/Dockerfile .
+# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile .
#
# There is an optional PYTHON_VERSION build argument which sets the
# version of python to build against: for example:
#
-# docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.6 .
+# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 .
#
ARG PYTHON_VERSION=3.8
@@ -19,7 +22,16 @@ ARG PYTHON_VERSION=3.8
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
-RUN apt-get update && apt-get install -y \
+#
+# RUN --mount is specific to buildkit and is documented at
+# https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md#build-mounts-run---mount.
+# Here we use it to set up a cache for apt, to improve rebuild speeds on
+# slow connections.
+#
+RUN \
+ --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ apt-get update && apt-get install -y \
build-essential \
libffi-dev \
libjpeg-dev \
@@ -44,7 +56,8 @@ COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py
# used while you develop on the source
#
# This is aiming at installing the `install_requires` and `extras_require` from `setup.py`
-RUN pip install --prefix="/install" --no-warn-script-location \
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install --prefix="/install" --no-warn-script-location \
/synapse[all]
# Copy over the rest of the project
@@ -66,7 +79,10 @@ LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/syna
LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git'
LABEL org.opencontainers.image.licenses='Apache-2.0'
-RUN apt-get update && apt-get install -y \
+RUN \
+ --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ apt-get update && apt-get install -y \
curl \
gosu \
libjpeg62-turbo \
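> Note: the cache mounts above require BuildKit, so a plain `docker build` on an older Docker daemon will now fail to parse the Dockerfile. A minimal sketch of driving the documented `DOCKER_BUILDKIT=1 docker build` invocation from Python; the helper name and defaults are illustrative only:

```python
import os
import subprocess

def build_synapse_image(repo_root: str, python_version: str = "3.8") -> None:
    """Run `docker build` with BuildKit enabled, as docker/Dockerfile now requires."""
    env = dict(os.environ, DOCKER_BUILDKIT="1")  # enables RUN --mount=type=cache
    subprocess.run(
        [
            "docker", "build",
            "-f", "docker/Dockerfile",
            "--build-arg", f"PYTHON_VERSION={python_version}",
            ".",
        ],
        cwd=repo_root,  # must be the synapse repository root
        env=env,
        check=True,
    )
```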
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 74933d2fc..c514cadb9 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -15,9 +15,10 @@ server admin: [Admin API](../usage/administration/admin_api)
It returns a JSON body like the following:
-```json
+```jsonc
{
- "displayname": "User",
+ "name": "@user:example.com",
+ "displayname": "User", // can be null if not set
"threepids": [
{
"medium": "email",
@@ -32,11 +33,11 @@ It returns a JSON body like the following:
"validated_at": 1586458409743
}
],
- "avatar_url": "",
+ "avatar_url": "", // can be null if not set
+ "is_guest": 0,
"admin": 0,
"deactivated": 0,
"shadow_banned": 0,
- "password_hash": "$2b$12$p9B4GkqYdRTPGD",
"creation_ts": 1560432506,
"appservice_id": null,
"consent_server_notice_sent": null,
diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md
index abdb80843..c14298169 100644
--- a/docs/development/contributing_guide.md
+++ b/docs/development/contributing_guide.md
@@ -20,7 +20,9 @@ recommended for development. More information about WSL can be found at
<https://docs.microsoft.com/en-us/windows/wsl/install>. Running Synapse natively
on Windows is not officially supported.
-The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download).
+The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://www.python.org/downloads/). Your Python also needs support for [virtual environments](https://docs.python.org/3/library/venv.html). This is usually built-in, but some Linux distributions like Debian and Ubuntu split it out into its own package. Running `sudo apt install python3-venv` should be enough.
+
+Synapse can connect to PostgreSQL via the [psycopg2](https://pypi.org/project/psycopg2/) Python library. Building this library from source requires access to PostgreSQL's C header files. On Debian or Ubuntu Linux, these can be installed with `sudo apt install libpq-dev`.
The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git).
@@ -169,6 +171,27 @@ To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`:
SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests
```
+By default, tests will use an in-memory SQLite database for test data. For additional
+help with debugging, one can use an on-disk SQLite database file instead, in order to
+review database state during and after running tests. This can be done by setting
+the `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable. Doing so will cause the
+database state to be stored in a file named `test.db` under the trial process'
+working directory. Typically, this ends up being `_trial_temp/test.db`. For example:
+
+```sh
+SYNAPSE_TEST_PERSIST_SQLITE_DB=1 trial tests
+```
+
+The database file can then be inspected with:
+
+```sh
+sqlite3 _trial_temp/test.db
+```
+
+Note that the database file is cleared at the beginning of each test run. Thus it
+will only ever contain the data generated by the *most recent test run*. Generally,
+when debugging, only a single test is being run anyway.
+
### Running tests under PostgreSQL
Invoking `trial` as above will use an in-memory SQLite database. This is great for
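> Alongside the `sqlite3` CLI shown in the hunk above, the persisted database can also be inspected from Python's standard library; a minimal sketch:

```python
import sqlite3

# After `SYNAPSE_TEST_PERSIST_SQLITE_DB=1 trial <some test>`, list the tables
# left behind by the most recent test run.
conn = sqlite3.connect("_trial_temp/test.db")
try:
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    for (name,) in rows:
        print(name)
finally:
    conn.close()
```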
diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md
index aff381360..154b9a5e1 100644
--- a/docs/development/url_previews.md
+++ b/docs/development/url_previews.md
@@ -35,7 +35,12 @@ When Synapse is asked to preview a URL it does the following:
5. If the media is HTML:
1. Decodes the HTML via the stored file.
2. Generates an Open Graph response from the HTML.
- 3. If an image exists in the Open Graph response:
+ 3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+ 1. Downloads the URL and stores it into a file via the media storage provider
+ and saves the local media metadata.
+ 2. Converts the oEmbed response to an Open Graph response.
+ 3. Overrides any Open Graph data from the HTML with data from oEmbed.
+ 4. If an image exists in the Open Graph response:
1. Downloads the URL and stores it into a file via the media storage
provider and saves the local media metadata.
2. Generates thumbnails.
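> Step 3 of the oEmbed branch above is a simple precedence rule: Open Graph data derived from the oEmbed response overrides whatever was scraped from the HTML. An illustrative sketch (not Synapse's actual implementation):

```python
def merge_open_graph(html_og: dict, oembed_og: dict) -> dict:
    """Open Graph data from oEmbed wins over data extracted from the HTML."""
    merged = dict(html_og)
    merged.update({k: v for k, v in oembed_og.items() if v is not None})
    return merged

# The oEmbed title replaces the scraped one; the HTML-only field survives.
merged = merge_open_graph(
    {"og:title": "From HTML", "og:description": "scraped"},
    {"og:title": "From oEmbed"},
)
assert merged == {"og:title": "From oEmbed", "og:description": "scraped"}
```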
diff --git a/docs/openid.md b/docs/openid.md
index ff9de9d5b..171ea3b71 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -390,9 +390,6 @@ oidc_providers:
### Facebook
-Like Github, Facebook provide a custom OAuth2 API rather than an OIDC-compliant
-one so requires a little more configuration.
-
0. You will need a Facebook developer account. You can register for one
[here](https://developers.facebook.com/async/registration/).
1. On the [apps](https://developers.facebook.com/apps/) page of the developer
@@ -412,24 +409,28 @@ Synapse config:
idp_name: Facebook
idp_brand: "facebook" # optional: styling hint for clients
discover: false
- issuer: "https://facebook.com"
+ issuer: "https://www.facebook.com"
client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED
scopes: ["openid", "email"]
- authorization_endpoint: https://facebook.com/dialog/oauth
- token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
- user_profile_method: "userinfo_endpoint"
- userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+ authorization_endpoint: "https://facebook.com/dialog/oauth"
+ token_endpoint: "https://graph.facebook.com/v9.0/oauth/access_token"
+ jwks_uri: "https://www.facebook.com/.well-known/oauth/openid/jwks/"
user_mapping_provider:
config:
- subject_claim: "id"
display_name_template: "{{ user.name }}"
+ email_template: "{{ '{{ user.email }}' }}"
```
Relevant documents:
- * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
- * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
- * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
+ * [Manually Build a Login Flow](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow)
+ * [Using Facebook's Graph API](https://developers.facebook.com/docs/graph-api/using-graph-api/)
+ * [Reference to the User endpoint](https://developers.facebook.com/docs/graph-api/reference/user)
+
+Facebook do have an [OIDC discovery endpoint](https://www.facebook.com/.well-known/openid-configuration),
+but it has a `response_types_supported` which excludes "code" (which we rely on, and
+is even mentioned in their [documentation](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow#login)),
+so we have to disable discovery and configure the URIs manually.
### Gitea
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 810a14b07..1b86d0295 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -74,13 +74,7 @@ server_name: "SERVERNAME"
#
pid_file: DATADIR/homeserver.pid
-# The absolute URL to the web client which /_matrix/client will redirect
-# to if 'webclient' is configured under the 'listeners' configuration.
-#
-# This option can be also set to the filesystem path to the web client
-# which will be served at /_matrix/client/ if 'webclient' is configured
-# under the 'listeners' configuration, however this is a security risk:
-# https://github.com/matrix-org/synapse#security-note
+# The absolute URL to the web client which / will redirect to.
#
#web_client_location: https://riot.example.com/
@@ -164,7 +158,7 @@ presence:
# The default room version for newly created rooms.
#
# Known room versions are listed here:
-# https://matrix.org/docs/spec/#complete-list-of-room-versions
+# https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
@@ -310,8 +304,6 @@ presence:
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
#
-# webclient: A web client. Requires web_client_location to be set.
-#
listeners:
# TLS-enabled listener: for when matrix traffic is sent directly to synapse.
#
@@ -1503,6 +1495,21 @@ room_prejoin_state:
#additional_event_types:
# - org.example.custom.event.type
+# We record the IP address of clients used to access the API for various
+# reasons, including displaying it to the user in the "Where you're signed in"
+# dialog.
+#
+# By default, when puppeting another user via the admin API, the client IP
+# address is recorded against the user who created the access token (ie, the
+# admin user), and *not* the puppeted user.
+#
+# Uncomment the following to also record the IP address against the puppeted
+# user. (This also means that the puppeted user will count as an "active" user
+# for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc
+# above.)
+#
+#track_puppeted_user_ips: true
+
# A list of application service config files to use
#
@@ -1870,10 +1877,13 @@ saml2_config:
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
-# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+# endpoint, or to rely on the data returned in the id_token from the
+# token_endpoint.
#
-# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
-# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+# Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+# Defaults to 'auto', which uses the userinfo endpoint if 'openid' is
+# not included in 'scopes'. Set to 'userinfo_endpoint' to always use the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
diff --git a/docs/setup/installation.md b/docs/setup/installation.md
index 210c80dac..fe657a15d 100644
--- a/docs/setup/installation.md
+++ b/docs/setup/installation.md
@@ -194,7 +194,7 @@ When following this route please make sure that the [Platform-specific prerequis
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
-- Python 3.6 or later, up to Python 3.9.
+- Python 3.7 or later, up to Python 3.9.
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
To install the Synapse homeserver run:
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index e32aaa185..eba7ca612 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -137,6 +137,10 @@ This will install and start a systemd service called `coturn`.
# TLS private key file
pkey=/path/to/privkey.pem
+
+ # Ensure the configuration lines that disable TLS/DTLS are commented-out or removed
+ #no-tls
+ #no-dtls
```
In this case, replace the `turn:` schemes in the `turn_uris` settings below
@@ -145,6 +149,14 @@ This will install and start a systemd service called `coturn`.
We recommend that you only try to set up TLS/DTLS once you have set up a
basic installation and got it working.
+ NB: If your TLS certificate was provided by Let's Encrypt, TLS/DTLS will
+ not work with any Matrix client that uses Chromium's WebRTC library. This
+ currently includes Element Android & iOS; for more details, see their
+ [respective](https://github.com/vector-im/element-android/issues/1533)
+ [issues](https://github.com/vector-im/element-ios/issues/2712) as well as the underlying
+ [WebRTC issue](https://bugs.chromium.org/p/webrtc/issues/detail?id=11710).
+ Consider using a ZeroSSL certificate for your TURN server as a working alternative.
+
1. Ensure your firewall allows traffic into the TURN server on the ports
you've configured it to listen on (By default: 3478 and 5349 for TURN
traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
@@ -250,6 +262,10 @@ Here are a few things to try:
* Check that you have opened your firewall to allow UDP traffic to the UDP
relay ports (49152-65535 by default).
+ * Try disabling `coturn`'s TLS/DTLS listeners and enable only its (unencrypted)
+ TCP/UDP listeners. (This will only leave signaling traffic unencrypted;
+ voice & video WebRTC traffic is always encrypted.)
+
* Some WebRTC implementations (notably, that of Google Chrome) appear to get
confused by TURN servers which are reachable over IPv6 (this appears to be
an unexpected side-effect of its handling of multiple IP addresses as
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 30bb0dcd9..f455d257b 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,17 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.51.0
+
+## Deprecation of `webclient` listeners and non-HTTP(S) `web_client_location`
+
+Listeners of type `webclient` are deprecated and scheduled to be removed in
+Synapse v1.53.0.
+
+Similarly, a non-HTTP(S) `web_client_location` configuration is deprecated and
+will become a configuration error in Synapse v1.53.0.
+
+
# Upgrading to v1.50.0
## Dropping support for old Python and Postgres versions
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 53295b58f..e08ffedaf 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -8,7 +8,8 @@
# By default the script will fetch the latest Complement master branch and
# run tests with that. This can be overridden to use a custom Complement
# checkout by setting the COMPLEMENT_DIR environment variable to the
-# filepath of a local Complement checkout.
+# filepath of a local Complement checkout or by setting the COMPLEMENT_REF
+# environment variable to pull a different branch or commit.
#
# By default Synapse is run in monolith mode. This can be overridden by
# setting the WORKERS environment variable.
@@ -23,16 +24,20 @@
# Exit if a line returns a non-zero exit code
set -e
+# enable buildkit for the docker builds
+export DOCKER_BUILDKIT=1
+
# Change to the repository root
cd "$(dirname $0)/.."
# Check for a user-specified Complement checkout
if [[ -z "$COMPLEMENT_DIR" ]]; then
- echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
- wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
- tar -xzf master.tar.gz
- COMPLEMENT_DIR=complement-master
- echo "Checkout available at 'complement-master'"
+ COMPLEMENT_REF=${COMPLEMENT_REF:-master}
+ echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..."
+ wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz
+ tar -xzf ${COMPLEMENT_REF}.tar.gz
+ COMPLEMENT_DIR=complement-${COMPLEMENT_REF}
+ echo "Checkout available at 'complement-${COMPLEMENT_REF}'"
fi
# Build the base Synapse image from the local checkout
@@ -47,7 +52,7 @@ if [[ -n "$WORKERS" ]]; then
COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile
# And provide some more configuration to complement.
export COMPLEMENT_CA=true
- export COMPLEMENT_VERSION_CHECK_ITERATIONS=500
+ export COMPLEMENT_SPAWN_HS_TIMEOUT_SECS=25
else
export COMPLEMENT_BASE_IMAGE=complement-synapse
COMPLEMENT_DOCKERFILE=Synapse.Dockerfile
@@ -65,4 +70,5 @@ if [[ -n "$1" ]]; then
fi
# Run the tests!
+echo "Images built; running complement"
go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 5ef294cd4..26bdfec33 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.50.2"
+__version__ = "1.51.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index 093af4327..e207f154f 100644
--- a/synapse/_scripts/review_recent_signups.py
+++ b/synapse/_scripts/review_recent_signups.py
@@ -46,7 +46,9 @@ class UserInfo:
ips: List[str] = attr.Factory(list)
-def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
+def get_recent_users(
+ txn: LoggingTransaction, since_ms: int, exclude_app_service: bool
+) -> List[UserInfo]:
"""Fetches recently registered users and some info on them."""
sql = """
@@ -56,6 +58,9 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
AND deactivated = 0
"""
+ if exclude_app_service:
+ sql += " AND appservice_id IS NULL"
+
txn.execute(sql, (since_ms / 1000,))
user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn]
@@ -113,7 +118,7 @@ def main() -> None:
"-e",
"--exclude-emails",
action="store_true",
- help="Exclude users that have validated email addresses",
+ help="Exclude users that have validated email addresses.",
)
parser.add_argument(
"-u",
@@ -121,6 +126,12 @@ def main() -> None:
action="store_true",
help="Only print user IDs that match.",
)
+ parser.add_argument(
+ "-a",
+ "--exclude-app-service",
+ help="Exclude appservice users.",
+ action="store_true",
+ )
config = ReviewConfig()
@@ -133,6 +144,7 @@ def main() -> None:
since_ms = time.time() * 1000 - Config.parse_duration(config_args.since)
exclude_users_with_email = config_args.exclude_emails
+ exclude_users_with_appservice = config_args.exclude_app_service
include_context = not config_args.only_users
for database_config in config.database.databases:
@@ -143,7 +155,7 @@ def main() -> None:
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
# This generates a type of Cursor, not LoggingTransaction.
- user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type]
+ user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type]
for user_info in user_infos:
if exclude_users_with_email and user_info.emails:
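> The effect of the new `--exclude-app-service` flag on the query is just an extra `WHERE` clause; a sketch (the column list is an assumption, the filter matches the hunk above):

```python
def build_recent_users_sql(exclude_app_service: bool) -> str:
    sql = """
        SELECT name, creation_ts FROM users
        WHERE creation_ts > ?
          AND deactivated = 0
    """
    if exclude_app_service:
        # Appservice-registered users carry an appservice_id; organic signups do not.
        sql += " AND appservice_id IS NULL"
    return sql

assert "appservice_id IS NULL" in build_recent_users_sql(True)
assert "appservice_id" not in build_recent_users_sql(False)
```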
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 4a32d430b..683241201 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -71,6 +71,7 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+ self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -246,6 +247,18 @@ class Auth:
user_agent=user_agent,
device_id=device_id,
)
+ # Also track the puppeted user's client IP if enabled and the request is puppeted
+ if (
+ user_info.user_id != user_info.token_owner
+ and self._track_puppeted_user_ips
+ ):
+ await self.store.insert_client_ip(
+ user_id=user_info.user_id,
+ access_token=access_token,
+ ip=ip_addr,
+ user_agent=user_agent,
+ device_id=device_id,
+ )
if is_guest and not allow_guest:
raise AuthError(
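> The new branch amounts to recording the client IP against a second user for puppeted requests. A minimal sketch of the rule (names illustrative):

```python
from typing import List

def users_to_record_ip_for(
    user_id: str, token_owner: str, track_puppeted_user_ips: bool
) -> List[str]:
    """Which users a client IP is recorded against for a single request."""
    users = [token_owner]  # default: only the real owner of the access token
    if user_id != token_owner and track_puppeted_user_ips:
        users.append(user_id)  # additionally the puppeted user, if configured
    return users

assert users_to_record_ip_for("@alice:hs", "@admin:hs", True) == ["@admin:hs", "@alice:hs"]
assert users_to_record_ip_for("@alice:hs", "@admin:hs", False) == ["@admin:hs"]
```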
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 0a895bba4..a747a4081 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -46,41 +46,41 @@ class RoomDisposition:
UNSTABLE = "unstable"
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomVersion:
"""An object which describes the unique attributes of a room version."""
- identifier = attr.ib(type=str) # the identifier for this version
- disposition = attr.ib(type=str) # one of the RoomDispositions
- event_format = attr.ib(type=int) # one of the EventFormatVersions
- state_res = attr.ib(type=int) # one of the StateResolutionVersions
- enforce_key_validity = attr.ib(type=bool)
+ identifier: str # the identifier for this version
+ disposition: str # one of the RoomDispositions
+ event_format: int # one of the EventFormatVersions
+ state_res: int # one of the StateResolutionVersions
+ enforce_key_validity: bool
# Before MSC2432, m.room.aliases had special auth rules and redaction rules
- special_case_aliases_auth = attr.ib(type=bool)
+ special_case_aliases_auth: bool
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
# * Floats
# * NaN, Infinity, -Infinity
- strict_canonicaljson = attr.ib(type=bool)
+ strict_canonicaljson: bool
# MSC2209: Check 'notifications' key while verifying
# m.room.power_levels auth rules.
- limit_notifications_power_levels = attr.ib(type=bool)
+ limit_notifications_power_levels: bool
# MSC2174/MSC2176: Apply updated redaction rules algorithm.
- msc2176_redaction_rules = attr.ib(type=bool)
+ msc2176_redaction_rules: bool
# MSC3083: Support the 'restricted' join_rule.
- msc3083_join_rules = attr.ib(type=bool)
+ msc3083_join_rules: bool
# MSC3375: Support for the proper redaction rules for MSC3083. This mustn't
# be enabled if MSC3083 is not.
- msc3375_redaction_rules = attr.ib(type=bool)
+ msc3375_redaction_rules: bool
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
- msc2403_knocking = attr.ib(type=bool)
+ msc2403_knocking: bool
# MSC2716: Adds m.room.power_levels -> content.historical field to control
# whether "insertion", "chunk", "marker" events can be sent
- msc2716_historical = attr.ib(type=bool)
+ msc2716_historical: bool
# MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
- msc2716_redactions = attr.ib(type=bool)
+ msc2716_redactions: bool
class RoomVersions:
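> This `auto_attribs` conversion recurs throughout the patch. The two spellings below are equivalent; `auto_attribs=True` simply lets attrs read the native type annotations instead of explicit `attr.ib(type=...)` declarations:

```python
import attr

@attr.s(slots=True, frozen=True)
class OldStyle:
    identifier = attr.ib(type=str)
    enforce_key_validity = attr.ib(type=bool)

@attr.s(slots=True, frozen=True, auto_attribs=True)
class NewStyle:
    identifier: str
    enforce_key_validity: bool

# Both declarations record the same field types.
assert attr.fields(OldStyle)[0].type is attr.fields(NewStyle)[0].type is str
```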
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 5fc59c1be..579adbbca 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -60,7 +60,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
-from synapse.metrics import register_threadpool
+from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.types import ISynapseReactor
@@ -159,6 +159,7 @@ def start_reactor(
change_resource_limit(soft_file_limit)
if gc_thresholds:
gc.set_threshold(*gc_thresholds)
+ install_gc_manager()
run_command()
# make sure that we run the reactor with the sentinel log context,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dd76e0732..efedcc888 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -131,9 +131,18 @@ class SynapseHomeServer(HomeServer):
resources.update(self._module_web_resources)
self._module_web_resources_consumed = True
- # try to find something useful to redirect '/' to
- if WEB_CLIENT_PREFIX in resources:
- root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
+ # Try to find something useful to serve at '/':
+ #
+ # 1. Redirect to the web client if it is an HTTP(S) URL.
+ # 2. Redirect to the web client served via Synapse.
+ # 3. Redirect to the static "Synapse is running" page.
+ # 4. Do not redirect and use a blank resource.
+ if self.config.server.web_client_location_is_redirect:
+ root_resource: Resource = RootOptionsRedirectResource(
+ self.config.server.web_client_location
+ )
+ elif WEB_CLIENT_PREFIX in resources:
+ root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
@@ -262,15 +271,15 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
+ # webclient listeners are deprecated as of Synapse v1.51.0 and will be
+ # removed in Synapse v1.53.0.
webclient_loc = self.config.server.web_client_location
if webclient_loc is None:
logger.warning(
"Not enabling webclient resource, as web_client_location is unset."
)
- elif webclient_loc.startswith("http://") or webclient_loc.startswith(
- "https://"
- ):
+ elif self.config.server.web_client_location_is_redirect:
resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc)
else:
logger.warning(
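> The four-step comment in the hunk above compresses to a small decision function; an illustrative sketch (the prefix strings are assumed to match `WEB_CLIENT_PREFIX` and `STATIC_PREFIX`):

```python
from typing import Dict, Optional

def root_redirect_target(
    web_client_is_redirect: bool,
    web_client_location: Optional[str],
    resources: Dict[str, object],
) -> Optional[str]:
    if web_client_is_redirect and web_client_location:
        return web_client_location      # 1. explicit HTTP(S) web client URL
    if "/_matrix/client" in resources:
        return "/_matrix/client"        # 2. web client served by Synapse
    if "/_matrix/static" in resources:
        return "/_matrix/static"        # 3. static "Synapse is running" page
    return None                         # 4. leave '/' as a blank resource
```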
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 25538b82d..8133b6b62 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -29,6 +29,7 @@ class ApiConfig(Config):
def read_config(self, config: JsonDict, **kwargs):
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+ self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
def generate_config_section(cls, **kwargs) -> str:
formatted_default_state_types = "\n".join(
@@ -59,6 +60,21 @@ class ApiConfig(Config):
#
#additional_event_types:
# - org.example.custom.event.type
+
+ # We record the IP address of clients used to access the API for various
+ # reasons, including displaying it to the user in the "Where you're signed in"
+ # dialog.
+ #
+ # By default, when puppeting another user via the admin API, the client IP
+ # address is recorded against the user who created the access token (ie, the
+ # admin user), and *not* the puppeted user.
+ #
+ # Uncomment the following to also record the IP address against the puppeted
+ # user. (This also means that the puppeted user will count as an "active" user
+ # for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc
+ # above.)
+ #
+ #track_puppeted_user_ips: true
""" % {
"formatted_default_state_types": formatted_default_state_types
}
@@ -138,5 +154,8 @@ _MAIN_SCHEMA = {
"properties": {
"room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
"room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+ "track_puppeted_user_ips": {
+ "type": "boolean",
+ },
},
}
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 510b647c6..949d7dd5a 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -55,19 +55,19 @@ https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EmailSubjectConfig:
- message_from_person_in_room = attr.ib(type=str)
- message_from_person = attr.ib(type=str)
- messages_from_person = attr.ib(type=str)
- messages_in_room = attr.ib(type=str)
- messages_in_room_and_others = attr.ib(type=str)
- messages_from_person_and_others = attr.ib(type=str)
- invite_from_person = attr.ib(type=str)
- invite_from_person_to_room = attr.ib(type=str)
- invite_from_person_to_space = attr.ib(type=str)
- password_reset = attr.ib(type=str)
- email_validation = attr.ib(type=str)
+ message_from_person_in_room: str
+ message_from_person: str
+ messages_from_person: str
+ messages_in_room: str
+ messages_in_room_and_others: str
+ messages_from_person_and_others: str
+ invite_from_person: str
+ invite_from_person_to_room: str
+ invite_from_person_to_space: str
+ password_reset: str
+ email_validation: str
class EmailConfig(Config):
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 79c400fe3..e783b1131 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -148,10 +148,13 @@ class OIDCConfig(Config):
# Defaults to false. Avoid this in production.
#
# user_profile_method: Whether to fetch the user profile from the userinfo
- # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+ # endpoint, or to rely on the data returned in the id_token from the
+ # token_endpoint.
#
- # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
- # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+ # Valid values are: 'auto' or 'userinfo_endpoint'.
+ #
+ # Defaults to 'auto', which uses the userinfo endpoint if 'openid' is
+ # not included in 'scopes'. Set to 'userinfo_endpoint' to always use the
# userinfo endpoint.
#
# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 1de2dea9b..f200d0c1f 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -200,8 +200,8 @@ class HttpListenerConfig:
"""Object describing the http-specific parts of the config of a listener"""
x_forwarded: bool = False
- resources: List[HttpResourceConfig] = attr.ib(factory=list)
- additional_resources: Dict[str, dict] = attr.ib(factory=dict)
+ resources: List[HttpResourceConfig] = attr.Factory(list)
+ additional_resources: Dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
@@ -259,7 +259,6 @@ class ServerConfig(Config):
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file"))
- self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config.get("soft_file_limit", 0)
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
@@ -506,8 +505,17 @@ class ServerConfig(Config):
l2.append(listener)
self.listeners = l2
- if not self.web_client_location:
- _warn_if_webclient_configured(self.listeners)
+ self.web_client_location = config.get("web_client_location", None)
+ self.web_client_location_is_redirect = self.web_client_location and (
+ self.web_client_location.startswith("http://")
+ or self.web_client_location.startswith("https://")
+ )
+ # A non-HTTP(S) web client location is deprecated.
+ if self.web_client_location and not self.web_client_location_is_redirect:
+ logger.warning(NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING)
+
+ # Warn if webclient is configured for a worker.
+ _warn_if_webclient_configured(self.listeners)
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
@@ -793,13 +801,7 @@ class ServerConfig(Config):
#
pid_file: %(pid_file)s
- # The absolute URL to the web client which /_matrix/client will redirect
- # to if 'webclient' is configured under the 'listeners' configuration.
- #
- # This option can be also set to the filesystem path to the web client
- # which will be served at /_matrix/client/ if 'webclient' is configured
- # under the 'listeners' configuration, however this is a security risk:
- # https://github.com/matrix-org/synapse#security-note
+ # The absolute URL to the web client which / will redirect to.
#
#web_client_location: https://riot.example.com/
@@ -883,7 +885,7 @@ class ServerConfig(Config):
# The default room version for newly created rooms.
#
# Known room versions are listed here:
- # https://matrix.org/docs/spec/#complete-list-of-room-versions
+ # https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
@@ -1011,8 +1013,6 @@ class ServerConfig(Config):
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
#
- # webclient: A web client. Requires web_client_location to be set.
- #
listeners:
# TLS-enabled listener: for when matrix traffic is sent directly to synapse.
#
@@ -1349,9 +1349,15 @@ def parse_listener_def(listener: Any) -> ListenerConfig:
return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING = """
+Synapse no longer supports serving a web client. To remove this warning,
+configure 'web_client_location' with an HTTP(S) URL.
+"""
+
+
NO_MORE_WEB_CLIENT_WARNING = """
-Synapse no longer includes a web client. To enable a web client, configure
-web_client_location. To remove this warning, remove 'webclient' from the 'listeners'
+Synapse no longer includes a web client. To redirect the root resource to a web client, configure
+'web_client_location'. To remove this warning, remove 'webclient' from the 'listeners'
configuration.
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 576f51918..bdaba6db3 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -51,12 +51,12 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj
-@attr.s
+@attr.s(auto_attribs=True)
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication."""
- host = attr.ib(type=str)
- port = attr.ib(type=int)
+ host: str
+ port: int
@attr.s
@@ -77,34 +77,28 @@ class WriterLocations:
can only be a single instance.
"""
- events = attr.ib(
+ events: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- typing = attr.ib(
+ typing: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- to_device = attr.ib(
+ to_device: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- account_data = attr.ib(
+ account_data: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- receipts = attr.ib(
+ receipts: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
- presence = attr.ib(
+ presence: List[str] = attr.ib(
default=["master"],
- type=List[str],
converter=_instance_to_list_converter,
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 993b04099..72d4a69aa 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-@attr.s(slots=True, cmp=False)
+@attr.s(slots=True, frozen=True, cmp=False, auto_attribs=True)
class VerifyJsonRequest:
"""
A request to verify a JSON object.
@@ -78,10 +78,10 @@ class VerifyJsonRequest:
key_ids: The set of key_ids to that could be used to verify the JSON object
"""
- server_name = attr.ib(type=str)
- get_json_object = attr.ib(type=Callable[[], JsonDict])
- minimum_valid_until_ts = attr.ib(type=int)
- key_ids = attr.ib(type=List[str])
+ server_name: str
+ get_json_object: Callable[[], JsonDict]
+ minimum_valid_until_ts: int
+ key_ids: List[str]
@staticmethod
def from_json_object(
@@ -124,7 +124,7 @@ class KeyLookupError(ValueError):
pass
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _FetchKeyRequest:
"""A request for keys for a given server.
@@ -138,9 +138,9 @@ class _FetchKeyRequest:
key_ids: The IDs of the keys to attempt to fetch
"""
- server_name = attr.ib(type=str)
- minimum_valid_until_ts = attr.ib(type=int)
- key_ids = attr.ib(type=List[str])
+ server_name: str
+ minimum_valid_until_ts: int
+ key_ids: List[str]
class Keyring:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f251402ed..0eab1aefd 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class EventContext:
"""
Holds information relevant to persisting an event
@@ -103,15 +103,15 @@ class EventContext:
accessed via get_prev_state_ids.
"""
- rejected = attr.ib(default=False, type=Union[bool, str])
- _state_group = attr.ib(default=None, type=Optional[int])
- state_group_before_event = attr.ib(default=None, type=Optional[int])
- prev_group = attr.ib(default=None, type=Optional[int])
- delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
- app_service = attr.ib(default=None, type=Optional[ApplicationService])
+ rejected: Union[bool, str] = False
+ _state_group: Optional[int] = None
+ state_group_before_event: Optional[int] = None
+ prev_group: Optional[int] = None
+ delta_ids: Optional[StateMap[str]] = None
+ app_service: Optional[ApplicationService] = None
- _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
- _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ _current_state_ids: Optional[StateMap[str]] = None
+ _prev_state_ids: Optional[StateMap[str]] = None
@staticmethod
def with_state(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 2038e7292..918adeecf 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,17 +14,7 @@
# limitations under the License.
import collections.abc
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Union,
-)
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from frozendict import frozendict
@@ -32,14 +22,10 @@ from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
-from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze
from . import EventBase
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
"""Serializes a single event.
@@ -418,66 +399,41 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, **kwargs)
# Check if there are any bundled aggregations to include with the event.
- #
- # Do not bundle aggregations if any of the following at true:
- #
- # * Support is disabled via the configuration or the caller.
- # * The event is a state event.
- # * The event has been redacted.
- if (
- self._msc1849_enabled
- and bundle_aggregations
- and not event.is_state()
- and not event.internal_metadata.is_redacted()
- ):
- await self._injected_bundled_aggregations(event, time_now, serialized_event)
+ if bundle_aggregations:
+ event_aggregations = bundle_aggregations.get(event.event_id)
+ if event_aggregations:
+ self._inject_bundled_aggregations(
+ event,
+ time_now,
+ bundle_aggregations[event.event_id],
+ serialized_event,
+ )
return serialized_event
- async def _injected_bundled_aggregations(
- self, event: EventBase, time_now: int, serialized_event: JsonDict
+ def _inject_bundled_aggregations(
+ self,
+ event: EventBase,
+ time_now: int,
+ aggregations: JsonDict,
+ serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Args:
event: The event being serialized.
time_now: The current time in milliseconds
+ aggregations: The bundled aggregation to serialize.
serialized_event: The serialized event which may be modified.
"""
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return
+ # Make a copy in case the object is cached.
+ aggregations = aggregations.copy()
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include.
- aggregations = {}
-
- annotations = await self.store.get_aggregation_groups_for_event(
- event_id, room_id
- )
- if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
-
- references = await self.store.get_relations_for_event(
- event_id, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
-
- edit = None
- if event.type == EventTypes.Message:
- edit = await self.store.get_applicable_edit(event_id, room_id)
-
- if edit:
+ if RelationTypes.REPLACE in aggregations:
# If there is an edit replace the content, preserving existing
# relations.
+ edit = aggregations[RelationTypes.REPLACE]
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
@@ -502,27 +458,19 @@ class EventClientSerializer:
}
# If this event is the start of a thread, include a summary of the replies.
- if self._msc3440_enabled:
- (
- thread_count,
- latest_thread_event,
- ) = await self.store.get_thread_summary(event_id, room_id)
- if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- # Don't bundle aggregations as this could recurse forever.
- "latest_event": await self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=False
- ),
- "count": thread_count,
- }
+ if RelationTypes.THREAD in aggregations:
+ # Serialize the latest thread event.
+ latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
- # If any bundled aggregations were found, include them.
- if aggregations:
- serialized_event["unsigned"].setdefault("m.relations", {}).update(
- aggregations
+ # Don't bundle aggregations as this could recurse forever.
+ aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
+ latest_thread_event, time_now, bundle_aggregations=None
)
- async def serialize_events(
+ # Include the bundled aggregations in the event.
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+
+ def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events.
@@ -535,9 +483,9 @@ class EventClientSerializer:
Returns:
The list of serialized events
"""
- return await yieldable_gather_results(
- self.serialize_event, events, time_now=time_now, **kwargs
- )
+ return [
+ self.serialize_event(event, time_now=time_now, **kwargs) for event in events
+ ]
def copy_power_levels_contents(
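For reference, a minimal standalone sketch (plain dicts and illustrative names only, not the Synapse API) of how pre-computed bundled aggregations end up under unsigned["m.relations"] without recursing:

THREAD = "m.thread"

def inject_bundled_aggregations(serialized_event, aggregations, serialize):
    # Copy first: the aggregations object may be shared via a cache.
    aggregations = dict(aggregations)

    if THREAD in aggregations:
        # Serialize the latest thread event *without* its own bundled
        # aggregations, otherwise serialization could recurse forever.
        thread = dict(aggregations[THREAD])
        thread["latest_event"] = serialize(thread["latest_event"])
        aggregations[THREAD] = thread

    serialized_event.setdefault("unsigned", {}).setdefault(
        "m.relations", {}
    ).update(aggregations)

# Example usage (identity "serializer" stands in for serialize_event):
event = {"type": "m.room.message", "unsigned": {}}
aggs = {THREAD: {"latest_event": {"body": "reply"}, "count": 1}}
inject_bundled_aggregations(event, aggs, serialize=lambda ev: ev)
assert event["unsigned"]["m.relations"][THREAD]["count"] == 1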
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index addc0bf00..896168c05 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -230,6 +230,10 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
# origin, etc etc)
assert_params_in_dict(pdu_json, ("type", "depth"))
+ # Strip any unauthorized values from "unsigned" if they exist
+ if "unsigned" in pdu_json:
+ _strip_unsigned_values(pdu_json)
+
depth = pdu_json["depth"]
if not isinstance(depth, int):
raise SynapseError(400, "Depth %r not an integer" % (depth,), Codes.BAD_JSON)
@@ -245,3 +249,24 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
event = make_event_from_dict(pdu_json, room_version)
return event
+
+
+def _strip_unsigned_values(pdu_dict: JsonDict) -> None:
+ """
+ Strip any unsigned values unless specifically allowed, as defined by the whitelist.
+
+ pdu_dict: the JSON dict to strip values from. Note that the dict is mutated
+ by this function.
+ """
+ unsigned = pdu_dict["unsigned"]
+
+ if not isinstance(unsigned, dict):
+ # Replace the malformed value and return early; otherwise the
+ # unsigned.items() call below would raise on a non-dict.
+ pdu_dict["unsigned"] = {}
+ return
+
+ if pdu_dict["type"] == "m.room.member":
+ whitelist = ["knock_room_state", "invite_room_state", "age"]
+ else:
+ whitelist = ["age"]
+
+ filtered_unsigned = {k: v for k, v in unsigned.items() if k in whitelist}
+ pdu_dict["unsigned"] = filtered_unsigned
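A hedged usage sketch of the whitelist behaviour above, re-implemented standalone with plain dicts (the helper name is illustrative):

def strip_unsigned(pdu: dict) -> None:
    # Mirrors _strip_unsigned_values: drop any unsigned keys a remote
    # server is not allowed to set.
    unsigned = pdu.get("unsigned", {})
    if not isinstance(unsigned, dict):
        pdu["unsigned"] = {}
        return
    if pdu["type"] == "m.room.member":
        allowed = {"knock_room_state", "invite_room_state", "age"}
    else:
        allowed = {"age"}
    pdu["unsigned"] = {k: v for k, v in unsigned.items() if k in allowed}

pdu = {
    "type": "m.room.member",
    "unsigned": {"age": 1234, "invite_room_state": [], "replaces_state": "$abc"},
}
strip_unsigned(pdu)
# "replaces_state" was supplied by the remote server and is dropped:
assert pdu["unsigned"] == {"age": 1234, "invite_room_state": []}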
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 6ea4edfc7..74f17aa4d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,7 +56,6 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -119,7 +118,8 @@ class FederationClient(FederationBase):
# It is a map of (room ID, suggested-only) -> the response of
# get_room_hierarchy.
self._get_room_hierarchy_cache: ExpiringCache[
- Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
+ Tuple[str, bool],
+ Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
] = ExpiringCache(
cache_name="get_room_hierarchy_cache",
clock=self._clock,
@@ -144,7 +144,6 @@ class FederationClient(FederationBase):
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
- @log_function
async def make_query(
self,
destination: str,
@@ -178,7 +177,6 @@ class FederationClient(FederationBase):
ignore_backoff=ignore_backoff,
)
- @log_function
async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
@@ -196,7 +194,6 @@ class FederationClient(FederationBase):
destination, content, timeout
)
- @log_function
async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
@@ -208,7 +205,6 @@ class FederationClient(FederationBase):
destination, user_id, timeout
)
- @log_function
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
@@ -1338,7 +1334,7 @@ class FederationClient(FederationBase):
destinations: Iterable[str],
room_id: str,
suggested_only: bool,
- ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+ ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
"""
Call other servers to get a hierarchy of the given room.
@@ -1353,7 +1349,8 @@ class FederationClient(FederationBase):
Returns:
A tuple of:
- The room as a JSON dictionary.
+ The room as a JSON dictionary, without a "children_state" key.
+ A list of `m.space.child` state events.
A list of children rooms, as JSON dictionaries.
A list of inaccessible children room IDs.
@@ -1368,7 +1365,7 @@ class FederationClient(FederationBase):
async def send_request(
destination: str,
- ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+ ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
try:
res = await self.transport_layer.get_room_hierarchy(
destination=destination,
@@ -1397,7 +1394,7 @@ class FederationClient(FederationBase):
raise InvalidResponseError("'room' must be a dict")
# Validate children_state of the room.
- children_state = room.get("children_state", [])
+ children_state = room.pop("children_state", [])
if not isinstance(children_state, Sequence):
raise InvalidResponseError("'room.children_state' must be a list")
if any(not isinstance(e, dict) for e in children_state):
@@ -1426,7 +1423,7 @@ class FederationClient(FederationBase):
"Invalid room ID in 'inaccessible_children' list"
)
- return room, children, inaccessible_children
+ return room, children_state, children, inaccessible_children
try:
result = await self._try_destination_list(
@@ -1474,8 +1471,6 @@ class FederationClient(FederationBase):
if event.room_id == room_id:
children_events.append(event.data)
children_room_ids.add(event.state_key)
- # And add them under the requested room.
- requested_room["children_state"] = children_events
# Find the children rooms.
children = []
@@ -1485,7 +1480,7 @@ class FederationClient(FederationBase):
# It isn't clear from the response whether some of the rooms are
# not accessible.
- result = (requested_room, children, ())
+ result = (requested_room, children_events, children, ())
# Cache the result to avoid fetching data over federation every time.
self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ee71f289c..af9cb98f6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -58,7 +58,6 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
-from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -859,7 +858,6 @@ class FederationServer(FederationBase):
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
- @log_function
async def on_query_client_keys(
self, origin: str, content: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
@@ -940,7 +938,6 @@ class FederationServer(FederationBase):
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
- @log_function
async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 523ab1c51..60e2e6cf0 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -23,7 +23,6 @@ import logging
from typing import Optional, Tuple
from synapse.federation.units import Transaction
-from synapse.logging.utils import log_function
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
@@ -36,7 +35,6 @@ class TransactionActions:
def __init__(self, datastore: DataStore):
self.store = datastore
- @log_function
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
@@ -53,7 +51,6 @@ class TransactionActions:
return await self.store.get_received_txn_response(transaction_id, origin)
- @log_function
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 391b30fbb..8152e80b8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -607,18 +607,18 @@ class PerDestinationQueue:
self._pending_pdus = []
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _TransactionQueueManager:
"""A helper async context manager for pulling stuff off the queues and
tracking what was last successfully sent, etc.
"""
- queue = attr.ib(type=PerDestinationQueue)
+ queue: PerDestinationQueue
- _device_stream_id = attr.ib(type=Optional[int], default=None)
- _device_list_id = attr.ib(type=Optional[int], default=None)
- _last_stream_ordering = attr.ib(type=Optional[int], default=None)
- _pdus = attr.ib(type=List[EventBase], factory=list)
+ _device_stream_id: Optional[int] = None
+ _device_list_id: Optional[int] = None
+ _last_stream_ordering: Optional[int] = None
+ _pdus: List[EventBase] = attr.Factory(list)
async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
# First we calculate the EDUs we want to send, if any.
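The attr.ib(type=...) to auto_attribs=True conversion above recurs throughout this changeset; a minimal before/after sketch (illustrative class, not taken from Synapse):

from typing import List, Optional

import attr

# Before: types passed to attr.ib() as arguments.
@attr.s(slots=True)
class QueueStateOld:
    stream_id = attr.ib(type=Optional[int], default=None)
    pdus = attr.ib(type=List[str], factory=list)

# After: plain annotations, collected via auto_attribs=True.
@attr.s(slots=True, auto_attribs=True)
class QueueState:
    stream_id: Optional[int] = None
    pdus: List[str] = attr.Factory(list)

assert QueueState().pdus == []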
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index ab935e5a7..742ee5725 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -35,6 +35,7 @@ if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
last_pdu_ts_metric = Gauge(
"synapse_federation_last_sent_pdu_time",
@@ -124,6 +125,17 @@ class TransactionManager:
len(pdus),
len(edus),
)
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ device_list_updates = [
+ edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS
+ ]
+ if device_list_updates:
+ issue_8631_logger.debug(
+ "about to send txn [%s] including device list updates: %s",
+ transaction.transaction_id,
+ device_list_updates,
+ )
# Actually send the transaction
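These debug lines only fire when the dedicated logger is enabled. Assuming the standard library logging setup Synapse uses, the equivalent of the following (normally expressed in a deployment's YAML log config) switches it on:

import logging

# Opt in to the issue-8631 debug output without raising the level of the
# rest of the "synapse" logger hierarchy.
logging.getLogger("synapse.8631_debug").setLevel(logging.DEBUG)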
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9fc4c31c9..8782586cd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,7 +44,6 @@ from synapse.api.urls import (
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
-from synapse.logging.utils import log_function
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -62,7 +61,6 @@ class TransportLayerClient:
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
- @log_function
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
@@ -88,7 +86,6 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
- @log_function
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
@@ -111,7 +108,6 @@ class TransportLayerClient:
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
- @log_function
async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
@@ -149,7 +145,6 @@ class TransportLayerClient:
destination, path=path, args=args, try_trailing_slash_on_400=True
)
- @log_function
async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str
) -> Union[JsonDict, List]:
@@ -185,7 +180,6 @@ class TransportLayerClient:
return remote_response
- @log_function
async def send_transaction(
self,
transaction: Transaction,
@@ -234,7 +228,6 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
- @log_function
async def make_query(
self,
destination: str,
@@ -254,7 +247,6 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff,
)
- @log_function
async def make_membership_event(
self,
destination: str,
@@ -317,7 +309,6 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff,
)
- @log_function
async def send_join_v1(
self,
room_version: RoomVersion,
@@ -336,7 +327,6 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- @log_function
async def send_join_v2(
self,
room_version: RoomVersion,
@@ -355,7 +345,6 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- @log_function
async def send_leave_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
@@ -372,7 +361,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def send_leave_v2(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> JsonDict:
@@ -389,7 +377,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def send_knock_v1(
self,
destination: str,
@@ -423,7 +410,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content
)
- @log_function
async def send_invite_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
@@ -433,7 +419,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def send_invite_v2(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> JsonDict:
@@ -443,7 +428,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def get_public_rooms(
self,
remote_server: str,
@@ -516,7 +500,6 @@ class TransportLayerClient:
return response
- @log_function
async def exchange_third_party_invite(
self, destination: str, room_id: str, event_dict: JsonDict
) -> JsonDict:
@@ -526,7 +509,6 @@ class TransportLayerClient:
destination=destination, path=path, data=event_dict
)
- @log_function
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
@@ -534,7 +516,6 @@ class TransportLayerClient:
return await self.client.get_json(destination=destination, path=path)
- @log_function
async def query_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
) -> JsonDict:
@@ -576,7 +557,6 @@ class TransportLayerClient:
destination=destination, path=path, data=query_content, timeout=timeout
)
- @log_function
async def query_user_devices(
self, destination: str, user_id: str, timeout: int
) -> JsonDict:
@@ -616,7 +596,6 @@ class TransportLayerClient:
destination=destination, path=path, timeout=timeout
)
- @log_function
async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
) -> JsonDict:
@@ -655,7 +634,6 @@ class TransportLayerClient:
destination=destination, path=path, data=query_content, timeout=timeout
)
- @log_function
async def get_missing_events(
self,
destination: str,
@@ -680,7 +658,6 @@ class TransportLayerClient:
timeout=timeout,
)
- @log_function
async def get_group_profile(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -694,7 +671,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def update_group_profile(
self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
@@ -716,7 +692,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_group_summary(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -730,7 +705,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_rooms_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -798,7 +772,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_users_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -812,7 +785,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_invited_users_in_group(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -826,7 +798,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def accept_group_invite(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -837,7 +808,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
def join_group(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> Awaitable[JsonDict]:
@@ -848,7 +818,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def invite_to_group(
self,
destination: str,
@@ -868,7 +837,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def invite_to_group_notification(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -882,7 +850,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def remove_user_from_group(
self,
destination: str,
@@ -902,7 +869,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def remove_user_from_group_notification(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -916,7 +882,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def renew_group_attestation(
self, destination: str, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
@@ -930,7 +895,6 @@ class TransportLayerClient:
destination=destination, path=path, data=content, ignore_backoff=True
)
- @log_function
async def update_group_summary_room(
self,
destination: str,
@@ -959,7 +923,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def delete_group_summary_room(
self,
destination: str,
@@ -986,7 +949,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_group_categories(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -1000,7 +962,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_group_category(
self, destination: str, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
@@ -1014,7 +975,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def update_group_category(
self,
destination: str,
@@ -1034,7 +994,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def delete_group_category(
self, destination: str, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
@@ -1048,7 +1007,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_group_roles(
self, destination: str, group_id: str, requester_user_id: str
) -> JsonDict:
@@ -1062,7 +1020,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def get_group_role(
self, destination: str, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
@@ -1076,7 +1033,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def update_group_role(
self,
destination: str,
@@ -1096,7 +1052,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def delete_group_role(
self, destination: str, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
@@ -1110,7 +1065,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def update_group_summary_user(
self,
destination: str,
@@ -1136,7 +1090,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def set_group_join_policy(
self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
@@ -1151,7 +1104,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- @log_function
async def delete_group_summary_user(
self,
destination: str,
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 77bfd88ad..beadfa422 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -36,6 +36,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
class BaseFederationServerServlet(BaseFederationServlet):
@@ -95,6 +96,20 @@ class FederationSendServlet(BaseFederationServerServlet):
len(transaction_data.get("edus", [])),
)
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"}
+ device_list_updates = [
+ # The EDUs in transaction_data are still raw JSON dicts at this
+ # point, not parsed Edu objects, so use dict access.
+ edu.get("content", {})
+ for edu in transaction_data.get("edus", [])
+ if edu.get("edu_type") in DEVICE_UPDATE_EDUS
+ ]
+ if device_list_updates:
+ issue_8631_logger.debug(
+ "received transaction [%s] including device list updates: %s",
+ transaction_id,
+ device_list_updates,
+ )
+
except Exception as e:
logger.exception(e)
return 400, {"error": "Invalid transaction"}
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 96273e2f8..bad48713b 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -77,7 +77,7 @@ class AccountDataHandler:
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
- """Add some account_data to a room for a user.
+ """Add some global account_data for a user.
Args:
user_id: The user to add some account data for.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 85157a138..00ab5e79b 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -55,21 +55,47 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
- ret = await self.store.get_user_by_id(user.to_string())
- if ret:
- profile = await self.store.get_profileinfo(user.localpart)
- threepids = await self.store.user_get_threepids(user.to_string())
- external_ids = [
- ({"auth_provider": auth_provider, "external_id": external_id})
- for auth_provider, external_id in await self.store.get_external_ids_by_user(
- user.to_string()
- )
- ]
- ret["displayname"] = profile.display_name
- ret["avatar_url"] = profile.avatar_url
- ret["threepids"] = threepids
- ret["external_ids"] = external_ids
- return ret
+ user_info_dict = await self.store.get_user_by_id(user.to_string())
+ if user_info_dict is None:
+ return None
+
+ # Restrict returned information to a known set of fields. This prevents additional
+ # fields added to get_user_by_id from modifying Synapse's external API surface.
+ user_info_to_return = {
+ "name",
+ "admin",
+ "deactivated",
+ "shadow_banned",
+ "creation_ts",
+ "appservice_id",
+ "consent_server_notice_sent",
+ "consent_version",
+ "user_type",
+ "is_guest",
+ }
+
+ # Restrict returned keys to a known set.
+ user_info_dict = {
+ key: value
+ for key, value in user_info_dict.items()
+ if key in user_info_to_return
+ }
+
+ # Add additional user metadata
+ profile = await self.store.get_profileinfo(user.localpart)
+ threepids = await self.store.user_get_threepids(user.to_string())
+ external_ids = [
+ ({"auth_provider": auth_provider, "external_id": external_id})
+ for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ user.to_string()
+ )
+ ]
+ user_info_dict["displayname"] = profile.display_name
+ user_info_dict["avatar_url"] = profile.avatar_url
+ user_info_dict["threepids"] = threepids
+ user_info_dict["external_ids"] = external_ids
+
+ return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 84724b207..bd1a32256 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -168,25 +168,25 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SsoLoginExtraAttributes:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
- creation_time = attr.ib(type=int)
- extra_attributes = attr.ib(type=JsonDict)
+ creation_time: int
+ extra_attributes: JsonDict
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class LoginTokenAttributes:
"""Data we store in a short-term login token"""
- user_id = attr.ib(type=str)
+ user_id: str
- auth_provider_id = attr.ib(type=str)
+ auth_provider_id: str
"""The SSO Identity Provider that the user authenticated with, to get this token."""
- auth_provider_session_id = attr.ib(type=Optional[str])
+ auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
@@ -2281,7 +2281,7 @@ class PasswordAuthProvider:
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
- callback(user_id, device_id, access_token)
+ await callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
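The one-line fix above matters because the registered callbacks are coroutine functions; a minimal standalone sketch of the failure mode (the callback here is hypothetical, not module API code):

import asyncio

async def on_logged_out(user_id, device_id, access_token):
    print("cleaned up", user_id)

async def main():
    # Bug: this only creates a coroutine object and discards it; the body
    # never runs (Python emits a "coroutine was never awaited" warning).
    on_logged_out("@alice:example.com", "DEV1", "token")

    # Fix: await it, so the callback actually executes.
    await on_logged_out("@alice:example.com", "DEV1", "token")

asyncio.run(main())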
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 766542523..b184a48cb 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -948,8 +948,16 @@ class DeviceListUpdater:
devices = []
ignore_devices = True
else:
+ prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
+ user_id
+ )
cached_devices = await self.store.get_cached_devices_for_user(user_id)
- if cached_devices == {d["device_id"]: d for d in devices}:
+
+ # To ensure that a user with no devices is cached, we skip the resync only
+ # if we have a stream_id from previously writing a cache entry.
+ if prev_stream_id is not None and cached_devices == {
+ d["device_id"]: d for d in devices
+ }:
logging.info(
"Skipping device list resync for %s, as our cache matches already",
user_id,
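A standalone sketch of the rule introduced above (illustrative helper; the real check lives inline in the resync code):

from typing import Optional

def need_resync(prev_stream_id: Optional[int], cached, remote_devices) -> bool:
    # An empty cache for a user with zero devices looks identical to "never
    # cached", so only trust a matching cache if a stream id was recorded
    # when the cache entry was written.
    return prev_stream_id is None or cached != {
        d["device_id"]: d for d in remote_devices
    }

assert need_resync(None, {}, [])   # never cached: must resync
assert not need_resync(5, {}, [])  # cached "no devices": can skip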
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 14360b4e4..d4dfddf63 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1321,14 +1321,14 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
return old_key == new_key_copy
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys."""
- signing_key_id = attr.ib(type=str)
- target_user_id = attr.ib(type=str)
- target_device_id = attr.ib(type=str)
- signature = attr.ib(type=JsonDict)
+ signing_key_id: str
+ target_user_id: str
+ target_device_id: str
+ signature: JsonDict
class SigningKeyEduUpdater:
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1b996c420..bac5de052 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -20,7 +20,6 @@ from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
-from synapse.logging.utils import log_function
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
@@ -43,7 +42,6 @@ class EventStreamHandler:
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @log_function
async def get_stream(
self,
auth_user_id: str,
@@ -119,7 +117,7 @@ class EventStreamHandler:
events.extend(to_add)
- chunks = await self._event_serializer.serialize_events(
+ chunks = self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 26b8e3f43..a37ae0ca0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -51,7 +51,6 @@ from synapse.logging.context import (
preserve_fn,
run_in_background,
)
-from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
@@ -556,7 +555,6 @@ class FederationHandler:
run_in_background(self._handle_queued_pdus, room_queue)
- @log_function
async def do_knock(
self,
target_hosts: List[str],
@@ -928,7 +926,6 @@ class FederationHandler:
return event
- @log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
@@ -1039,7 +1036,6 @@ class FederationHandler:
else:
return []
- @log_function
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
@@ -1056,7 +1052,6 @@ class FederationHandler:
return events
- @log_function
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
@@ -1118,7 +1113,6 @@ class FederationHandler:
return missing_events
- @log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
) -> None:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 11771f3c9..3905f60b3 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -56,7 +56,6 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import nested_logging_context, run_in_background
-from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
@@ -275,7 +274,6 @@ class FederationEventHandler:
await self._process_received_pdu(origin, pdu, state=None)
- @log_function
async def on_send_membership_event(
self, origin: str, event: EventBase
) -> Tuple[EventBase, EventContext]:
@@ -472,7 +470,6 @@ class FederationEventHandler:
return await self.persist_events_and_notify(room_id, [(event, context)])
- @log_function
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> None:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 601bab67f..346a06ff4 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -170,7 +170,7 @@ class InitialSyncHandler:
d["inviter"] = event.sender
invite_event = await self.store.get_event(event.event_id)
- d["invite"] = await self._event_serializer.serialize_event(
+ d["invite"] = self._event_serializer.serialize_event(
invite_event,
time_now,
as_client_event=as_client_event,
@@ -222,7 +222,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
- await self._event_serializer.serialize_events(
+ self._event_serializer.serialize_events(
messages,
time_now=time_now,
as_client_event=as_client_event,
@@ -232,7 +232,7 @@ class InitialSyncHandler:
"end": await end_token.to_string(self.store),
}
- d["state"] = await self._event_serializer.serialize_events(
+ d["state"] = self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
@@ -376,16 +376,14 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(messages, time_now)
+ self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- room_state.values(), time_now
- )
+ self._event_serializer.serialize_events(room_state.values(), time_now)
),
"presence": [],
"receipts": [],
@@ -404,7 +402,7 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
- state = await self._event_serializer.serialize_events(
+ state = self._event_serializer.serialize_events(
current_state.values(), time_now
)
@@ -480,7 +478,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(messages, time_now)
+ self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5e3d3886e..b37250aa3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -246,7 +246,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
- events = await self._event_serializer.serialize_events(room_state.values(), now)
+ events = self._event_serializer.serialize_events(room_state.values(), now)
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7469cc55a..973f26296 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -537,14 +537,16 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
+ aggregations = await self.store.get_bundled_aggregations(events, user_id)
+
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- await self._event_serializer.serialize_events(
+ self._event_serializer.serialize_events(
events,
time_now,
- bundle_aggregations=True,
+ bundle_aggregations=aggregations,
as_client_event=as_client_event,
)
),
@@ -553,7 +555,7 @@ class PaginationHandler:
}
if state:
- chunk["state"] = await self._event_serializer.serialize_events(
+ chunk["state"] = self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
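The serializer is now synchronous, so the calling convention above (seen again in /context and sync below) is: one async storage call fetches aggregations for the whole page, then events are serialized in a plain loop. A hedged sketch with illustrative names:

async def serialize_page(store, serializer, events, user_id, now_ms):
    # One storage round-trip for the whole page of events...
    aggregations = await store.get_bundled_aggregations(events, user_id)
    # ...then purely synchronous serialization, no per-event awaits.
    return [
        serializer.serialize_event(ev, now_ms, bundle_aggregations=aggregations)
        for ev in events
    ]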
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index c781fefb1..067c43ae4 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -55,7 +55,6 @@ from synapse.api.presence import UserPresenceState
from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
-from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.presence import (
@@ -1542,7 +1541,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @log_function
async def get_new_events(
self,
user: UserID,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9c1cbffa..f963078e5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -393,7 +393,9 @@ class RoomCreationHandler:
user_id = requester.user.to_string()
if not await self.spam_checker.user_may_create_room(user_id):
- raise SynapseError(403, "You are not permitted to create rooms")
+ raise SynapseError(
+ 403, "You are not permitted to create rooms", Codes.FORBIDDEN
+ )
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
@@ -685,7 +687,9 @@ class RoomCreationHandler:
invite_3pid_list,
)
):
- raise SynapseError(403, "You are not permitted to create rooms")
+ raise SynapseError(
+ 403, "You are not permitted to create rooms", Codes.FORBIDDEN
+ )
if ratelimit:
await self.request_ratelimiter.ratelimit(requester)
@@ -1177,6 +1181,22 @@ class RoomContextHandler:
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]
+ # Fetch the aggregations.
+ aggregations = await self.store.get_bundled_aggregations(
+ [results["event"]], user.to_string()
+ )
+ aggregations.update(
+ await self.store.get_bundled_aggregations(
+ results["events_before"], user.to_string()
+ )
+ )
+ aggregations.update(
+ await self.store.get_bundled_aggregations(
+ results["events_after"], user.to_string()
+ )
+ )
+ results["aggregations"] = aggregations
+
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
else:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index b2cfe537d..4844b69a0 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -153,6 +153,9 @@ class RoomSummaryHandler:
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
+ if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
+ max_rooms_per_space = MAX_ROOMS_PER_SPACE
+
while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft()
room_id = queue_entry.room_id
@@ -167,7 +170,7 @@ class RoomSummaryHandler:
# The client-specified max_rooms_per_space limit doesn't apply to the
# room_id specified in the request, so we ignore it if this is the
# first room we are processing.
- max_children = max_rooms_per_space if processed_rooms else None
+ max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS
if is_in_room:
room_entry = await self._summarize_local_room(
@@ -209,7 +212,7 @@ class RoomSummaryHandler:
# Before returning to the client, remove the allowed_room_ids
# and allowed_spaces keys.
room.pop("allowed_room_ids", None)
- room.pop("allowed_spaces", None)
+ room.pop("allowed_spaces", None) # historical
rooms_result.append(room)
events.extend(room_entry.children_state_events)
@@ -395,7 +398,7 @@ class RoomSummaryHandler:
None,
room_id,
suggested_only,
- # TODO Handle max children.
+ # Do not limit the maximum number of children.
max_children=None,
)
@@ -525,6 +528,10 @@ class RoomSummaryHandler:
rooms_result: List[JsonDict] = []
events_result: List[JsonDict] = []
+ # Set a limit on the number of rooms to return.
+ if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
+ max_rooms_per_space = MAX_ROOMS_PER_SPACE
+
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
if room_id in processed_rooms:
@@ -583,7 +590,9 @@ class RoomSummaryHandler:
# Iterate through each child and potentially add it, but not its children,
# to the response.
- for child_room in root_room_entry.children_state_events:
+ for child_room in itertools.islice(
+ root_room_entry.children_state_events, MAX_ROOMS_PER_SPACE
+ ):
room_id = child_room.get("state_key")
assert isinstance(room_id, str)
# If the room is unknown, skip it.
@@ -633,8 +642,8 @@ class RoomSummaryHandler:
suggested_only: True if only suggested children should be returned.
Otherwise, all children are returned.
max_children:
- The maximum number of children rooms to include. This is capped
- to a server-set limit.
+ The maximum number of children rooms to include. A value of None
+ means no limit.
Returns:
A room entry if the room should be returned. None, otherwise.
@@ -656,8 +665,13 @@ class RoomSummaryHandler:
# we only care about suggested children
child_events = filter(_is_suggested_child_event, child_events)
- if max_children is None or max_children > MAX_ROOMS_PER_SPACE:
- max_children = MAX_ROOMS_PER_SPACE
+ # TODO max_children is legacy code for the /spaces endpoint.
+ if max_children is not None:
+ child_iter: Iterable[EventBase] = itertools.islice(
+ child_events, max_children
+ )
+ else:
+ child_iter = child_events
stripped_events: List[JsonDict] = [
{
@@ -668,7 +682,7 @@ class RoomSummaryHandler:
"sender": e.sender,
"origin_server_ts": e.origin_server_ts,
}
- for e in itertools.islice(child_events, max_children)
+ for e in child_iter
]
return _RoomEntry(room_id, room_entry, stripped_events)
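The islice pattern above, as a standalone sketch (the helper name is illustrative; note that itertools.islice also accepts a stop of None to mean "no limit", so the explicit branch is mainly for clarity and typing):

import itertools
from typing import Iterable, List, Optional

def limit_children(children: Iterable[str], max_children: Optional[int]) -> List[str]:
    # None means "no limit"; islice truncates lazily without copying.
    if max_children is not None:
        return list(itertools.islice(children, max_children))
    return list(children)

assert limit_children(iter("abc"), 2) == ["a", "b"]
assert limit_children(iter("abc"), None) == ["a", "b", "c"]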
@@ -766,6 +780,7 @@ class RoomSummaryHandler:
try:
(
room_response,
+ children_state_events,
children,
inaccessible_children,
) = await self._federation_client.get_room_hierarchy(
@@ -790,7 +805,7 @@ class RoomSummaryHandler:
}
return (
- _RoomEntry(room_id, room_response, room_response.pop("children_state", ())),
+ _RoomEntry(room_id, room_response, children_state_events),
children_by_room_id,
set(inaccessible_children),
)
@@ -988,12 +1003,14 @@ class RoomSummaryHandler:
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
+ # The plural "join_rules" is a documentation error, but is kept for
+ # historical purposes. Should match /publicRooms.
"join_rules": stats["join_rules"],
+ "join_rule": stats["join_rules"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
"guest_can_join": stats["guest_access"] == "can_join",
- "creation_ts": create_event.origin_server_ts,
"room_type": create_event.content.get(EventContentFields.ROOM_TYPE),
}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ab7eaab2f..0b153a682 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -420,10 +420,10 @@ class SearchHandler:
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = await self._event_serializer.serialize_events(
+ context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
)
- context["events_after"] = await self._event_serializer.serialize_events(
+ context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
)
@@ -441,9 +441,7 @@ class SearchHandler:
results.append(
{
"rank": rank_map[e.event_id],
- "result": (
- await self._event_serializer.serialize_event(e, time_now)
- ),
+ "result": self._event_serializer.serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}),
}
)
@@ -457,7 +455,7 @@ class SearchHandler:
if state_results:
s = {}
for room_id, state_events in state_results.items():
- s[room_id] = await self._event_serializer.serialize_events(
+ s[room_id] = self._event_serializer.serialize_events(
state_events, time_now
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 65c27bc64..0bb8b0929 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -126,45 +126,45 @@ class SsoIdentityProvider(Protocol):
raise NotImplementedError()
-@attr.s
+@attr.s(auto_attribs=True)
class UserAttributes:
# the localpart of the mxid that the mapper has assigned to the user.
# if `None`, the mapper has not picked a userid, and the user should be prompted to
# enter one.
- localpart = attr.ib(type=Optional[str])
- display_name = attr.ib(type=Optional[str], default=None)
- emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+ localpart: Optional[str]
+ display_name: Optional[str] = None
+ emails: Collection[str] = attr.Factory(list)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class UsernameMappingSession:
"""Data we track about SSO sessions"""
# A unique identifier for this SSO provider, e.g. "oidc" or "saml".
- auth_provider_id = attr.ib(type=str)
+ auth_provider_id: str
# user ID on the IdP server
- remote_user_id = attr.ib(type=str)
+ remote_user_id: str
# attributes returned by the ID mapper
- display_name = attr.ib(type=Optional[str])
- emails = attr.ib(type=Collection[str])
+ display_name: Optional[str]
+ emails: Collection[str]
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
- extra_login_attributes = attr.ib(type=Optional[JsonDict])
+ extra_login_attributes: Optional[JsonDict]
# where to redirect the client back to
- client_redirect_url = attr.ib(type=str)
+ client_redirect_url: str
# expiry time for the session, in milliseconds
- expiry_time_ms = attr.ib(type=int)
+ expiry_time_ms: int
# choices made by the user
- chosen_localpart = attr.ib(type=Optional[str], default=None)
- use_display_name = attr.ib(type=bool, default=True)
- emails_to_use = attr.ib(type=Collection[str], default=())
- terms_accepted_version = attr.ib(type=Optional[str], default=None)
+ chosen_localpart: Optional[str] = None
+ use_display_name: bool = True
+ emails_to_use: Collection[str] = ()
+ terms_accepted_version: Optional[str] = None
# the HTTP cookie used to track the mapping session id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7baf3f199..ffc6b748e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -60,10 +60,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-# Debug logger for https://github.com/matrix-org/synapse/issues/4422
-issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug")
-
-
# Counts the number of times we returned a non-empty sync. `type` is one of
# "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is
# "true" or "false" depending on if the request asked for lazy loaded members or
@@ -102,6 +98,9 @@ class TimelineBatch:
prev_batch: StreamToken
events: List[EventBase]
limited: bool
+ # A mapping of event ID to the bundled aggregations for the above events.
+ # This is only calculated if limited is true.
+ bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -634,10 +633,19 @@ class SyncHandler:
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+ # Don't bother to bundle aggregations if the timeline is unlimited,
+ # as clients will have all the necessary information.
+ bundled_aggregations = None
+ if limited or newly_joined_room:
+ bundled_aggregations = await self.store.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
+
return TimelineBatch(
events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room,
+ bundled_aggregations=bundled_aggregations,
)
async def get_state_after_event(
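A minimal sketch of the new TimelineBatch rule (names are hypothetical): aggregations are only fetched for limited (gappy) timelines or newly joined rooms, where the client cannot reconstruct relations from the events it already has:

async def maybe_bundle(store, events, user_id, limited, newly_joined):
    if not (limited or newly_joined):
        # Unlimited timeline: the client has every event, so skip the
        # extra database work.
        return None
    return await store.get_bundled_aggregations(events, user_id)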
@@ -1161,13 +1169,8 @@ class SyncHandler:
num_events = 0
- # debug for https://github.com/matrix-org/synapse/issues/4422
+ # debug for https://github.com/matrix-org/synapse/issues/9424
for joined_room in sync_result_builder.joined:
- room_id = joined_room.room_id
- if room_id in newly_joined_rooms:
- issue4422_logger.debug(
- "Sync result for newly joined room %s: %r", room_id, joined_room
- )
num_events += len(joined_room.timeline.events)
log_kv(
@@ -1740,18 +1743,6 @@ class SyncHandler:
old_mem_ev_id, allow_none=True
)
- # debug for #4422
- if has_join:
- prev_membership = None
- if old_mem_ev:
- prev_membership = old_mem_ev.membership
- issue4422_logger.debug(
- "Previous membership for room %s with join: %s (event %s)",
- room_id,
- prev_membership,
- old_mem_ev_id,
- )
-
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
@@ -1893,13 +1884,6 @@ class SyncHandler:
upto_token=since_token,
)
- if newly_joined:
- # debugging for https://github.com/matrix-org/synapse/issues/4422
- issue4422_logger.debug(
- "RoomSyncResultBuilder events for newly joined room %s: %r",
- room_id,
- entry.events,
- )
room_entries.append(entry)
return _RoomChanges(
@@ -2077,14 +2061,6 @@ class SyncHandler:
# `_load_filtered_recents` can't find any events the user should see
# (e.g. due to having ignored the sender of the last 50 events).
- if newly_joined:
- # debug for https://github.com/matrix-org/synapse/issues/4422
- issue4422_logger.debug(
- "Timeline events after filtering in newly-joined room %s: %r",
- room_id,
- batch,
- )
-
# When we join the room (or the client requests full_state), we should
# send down any existing tags. Usually the user won't have tags in a
# newly joined room, unless either a) they've joined before or b) the
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index fbafffd69..203e995bb 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -32,9 +32,9 @@ class ProxyConnectError(ConnectError):
pass
-@attr.s
+@attr.s(auto_attribs=True)
class ProxyCredentials:
- username_password = attr.ib(type=bytes)
+ username_password: bytes
def as_proxy_authorization_value(self) -> bytes:
"""
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index deedde0b5..2e668363b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -123,37 +123,37 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
pass
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class MatrixFederationRequest:
- method = attr.ib(type=str)
+ method: str
"""HTTP method
"""
- path = attr.ib(type=str)
+ path: str
"""HTTP path
"""
- destination = attr.ib(type=str)
+ destination: str
"""The remote server to send the HTTP request to.
"""
- json = attr.ib(default=None, type=Optional[JsonDict])
+ json: Optional[JsonDict] = None
"""JSON to send in the body.
"""
- json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
+ json_callback: Optional[Callable[[], JsonDict]] = None
"""A callback to generate the JSON.
"""
- query = attr.ib(default=None, type=Optional[dict])
+ query: Optional[dict] = None
"""Query arguments.
"""
- txn_id = attr.ib(default=None, type=Optional[str])
+ txn_id: Optional[str] = None
"""Unique ID for this request (for logging)
"""
- uri = attr.ib(init=False, type=bytes)
+ uri: bytes = attr.ib(init=False)
"""The URI of this request
"""
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 80f7a2ff5..c180a1d32 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -534,9 +534,9 @@ class XForwardedForRequest(SynapseRequest):
@implementer(IAddress)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class _XForwardedForAddress:
- host = attr.ib(type=str)
+ host: str
class SynapseSite(Site):
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 8202d0494..475756f1d 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -39,7 +39,7 @@ from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True, auto_attribs=True)
@implementer(IPushProducer)
class LogProducer:
"""
@@ -54,10 +54,10 @@ class LogProducer:
# This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
- transport = attr.ib(type=Connection)
- _format = attr.ib(type=Callable[[logging.LogRecord], str])
- _buffer = attr.ib(type=deque)
- _paused = attr.ib(default=False, type=bool, init=False)
+ transport: Connection
+ _format: Callable[[logging.LogRecord], str]
+ _buffer: Deque[logging.LogRecord]
+ _paused: bool = attr.ib(default=False, init=False)
def pauseProducing(self):
self._paused = True
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index d4ee89337..c31c2960a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -193,7 +193,7 @@ class ContextResourceUsage:
return res
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ContextRequest:
"""
A bundle of attributes from the SynapseRequest object.
@@ -205,15 +205,15 @@ class ContextRequest:
their children.
"""
- request_id = attr.ib(type=str)
- ip_address = attr.ib(type=str)
- site_tag = attr.ib(type=str)
- requester = attr.ib(type=Optional[str])
- authenticated_entity = attr.ib(type=Optional[str])
- method = attr.ib(type=str)
- url = attr.ib(type=str)
- protocol = attr.ib(type=str)
- user_agent = attr.ib(type=str)
+ request_id: str
+ ip_address: str
+ site_tag: str
+ requester: Optional[str]
+ authenticated_entity: Optional[str]
+ method: str
+ url: str
+ protocol: str
+ user_agent: str
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 622445e9f..b240d2d21 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -247,11 +247,11 @@ try:
class BaseReporter: # type: ignore[no-redef]
pass
- @attr.s(slots=True, frozen=True)
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
class _WrappedRustReporter(BaseReporter):
"""Wrap the reporter to ensure `report_span` never throws."""
- _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
+ _reporter: Reporter = attr.Factory(Reporter)
def set_process(self, *args, **kwargs):
return self._reporter.set_process(*args, **kwargs)
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
deleted file mode 100644
index 4a01b902c..000000000
--- a/synapse/logging/utils.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-from functools import wraps
-from inspect import getcallargs
-from typing import Callable, TypeVar, cast
-
-_TIME_FUNC_ID = 0
-
-
-def _log_debug_as_f(f, msg, msg_args):
- name = f.__module__
- logger = logging.getLogger(name)
-
- if logger.isEnabledFor(logging.DEBUG):
- lineno = f.__code__.co_firstlineno
- pathname = f.__code__.co_filename
-
- record = logger.makeRecord(
- name=name,
- level=logging.DEBUG,
- fn=pathname,
- lno=lineno,
- msg=msg,
- args=msg_args,
- exc_info=None,
- )
-
- logger.handle(record)
-
-
-F = TypeVar("F", bound=Callable)
-
-
-def log_function(f: F) -> F:
- """Function decorator that logs every call to that function."""
- func_name = f.__name__
-
- @wraps(f)
- def wrapped(*args, **kwargs):
- name = f.__module__
- logger = logging.getLogger(name)
- level = logging.DEBUG
-
- if logger.isEnabledFor(level):
- bound_args = getcallargs(f, *args, **kwargs)
-
- def format(value):
- r = str(value)
- if len(r) > 50:
- r = r[:50] + "..."
- return r
-
- func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
-
- msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
-
- _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
-
- return f(*args, **kwargs)
-
- wrapped.__name__ = func_name
- return cast(F, wrapped)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index ceef57ad8..9e6c1b2f3 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -12,16 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
-import gc
import itertools
import logging
import os
import platform
import threading
-import time
from typing import (
- Any,
Callable,
Dict,
Generic,
@@ -34,35 +30,31 @@ from typing import (
Type,
TypeVar,
Union,
- cast,
)
import attr
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
from prometheus_client.core import (
REGISTRY,
- CounterMetricFamily,
GaugeHistogramMetricFamily,
GaugeMetricFamily,
)
-from twisted.internet import reactor
-from twisted.internet.base import ReactorBase
from twisted.python.threadpool import ThreadPool
-import synapse
+import synapse.metrics._reactor_metrics
from synapse.metrics._exposition import (
MetricsResource,
generate_latest,
start_http_server,
)
+from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
-running_on_pypy = platform.python_implementation() == "PyPy"
all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -76,19 +68,17 @@ class RegistryProxy:
yield metric
-@attr.s(slots=True, hash=True)
+@attr.s(slots=True, hash=True, auto_attribs=True)
class LaterGauge:
- name = attr.ib(type=str)
- desc = attr.ib(type=str)
- labels = attr.ib(hash=False, type=Optional[Iterable[str]])
+ name: str
+ desc: str
+ labels: Optional[Iterable[str]] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
- caller = attr.ib(
- type=Callable[
- [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
- ]
- )
+ caller: Callable[
+ [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
+ ]
def collect(self) -> Iterable[Metric]:
@@ -157,7 +147,9 @@ class InFlightGauge(Generic[MetricsEntry]):
# Create a class which have the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
self._metrics_class: Type[MetricsEntry] = attr.make_class(
- "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
+ "_MetricsEntry",
+ attrs={x: attr.ib(default=0) for x in sub_metrics},
+ slots=True,
)
# Counts number of in flight blocks for a given set of label values
@@ -369,136 +361,6 @@ class CPUMetrics:
REGISTRY.register(CPUMetrics())
-#
-# Python GC metrics
-#
-
-gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
-gc_time = Histogram(
- "python_gc_time",
- "Time taken to GC (sec)",
- ["gen"],
- buckets=[
- 0.0025,
- 0.005,
- 0.01,
- 0.025,
- 0.05,
- 0.10,
- 0.25,
- 0.50,
- 1.00,
- 2.50,
- 5.00,
- 7.50,
- 15.00,
- 30.00,
- 45.00,
- 60.00,
- ],
-)
-
-
-class GCCounts:
- def collect(self) -> Iterable[Metric]:
- cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
- for n, m in enumerate(gc.get_count()):
- cm.add_metric([str(n)], m)
-
- yield cm
-
-
-if not running_on_pypy:
- REGISTRY.register(GCCounts())
-
-
-#
-# PyPy GC / memory metrics
-#
-
-
-class PyPyGCStats:
- def collect(self) -> Iterable[Metric]:
-
- # @stats is a pretty-printer object with __str__() returning a nice table,
- # plus some fields that contain data from that table.
- # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB').
- stats = gc.get_stats(memory_pressure=False) # type: ignore
- # @s contains same fields as @stats, but as actual integers.
- s = stats._s # type: ignore
-
- # also note that field naming is completely braindead
- # and only vaguely correlates with the pretty-printed table.
- # >>>> gc.get_stats(False)
- # Total memory consumed:
- # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
- # in arenas: 3.0MB # s.total_arena_memory
- # rawmalloced: 1.7MB # s.total_rawmalloced_memory
- # nursery: 4.0MB # s.nursery_size
- # raw assembler used: 31.0kB # s.jit_backend_used
- # -----------------------------
- # Total: 8.8MB # stats.memory_used_sum
- #
- # Total memory allocated:
- # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
- # in arenas: 30.9MB # s.peak_arena_memory
- # rawmalloced: 4.1MB # s.peak_rawmalloced_memory
- # nursery: 4.0MB # s.nursery_size
- # raw assembler allocated: 1.0MB # s.jit_backend_allocated
- # -----------------------------
- # Total: 39.7MB # stats.memory_allocated_sum
- #
- # Total time spent in GC: 0.073 # s.total_gc_time
-
- pypy_gc_time = CounterMetricFamily(
- "pypy_gc_time_seconds_total",
- "Total time spent in PyPy GC",
- labels=[],
- )
- pypy_gc_time.add_metric([], s.total_gc_time / 1000)
- yield pypy_gc_time
-
- pypy_mem = GaugeMetricFamily(
- "pypy_memory_bytes",
- "Memory tracked by PyPy allocator",
- labels=["state", "class", "kind"],
- )
- # memory used by JIT assembler
- pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
- pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
- # memory used by GCed objects
- pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
- pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
- pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
- pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
- pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
- pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
- # totals
- pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
- pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
- pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
- pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
- yield pypy_mem
-
-
-if running_on_pypy:
- REGISTRY.register(PyPyGCStats())
-
-
-#
-# Twisted reactor metrics
-#
-
-tick_time = Histogram(
- "python_twisted_reactor_tick_time",
- "Tick time of the Twisted reactor (sec)",
- buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
-)
-pending_calls_metric = Histogram(
- "python_twisted_reactor_pending_calls",
- "Pending calls",
- buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000],
-)
#
# Federation Metrics
@@ -551,8 +413,6 @@ build_info.labels(
" ".join([platform.system(), platform.release()]),
).set(1)
-last_ticked = time.time()
-
# 3PID send info
threepid_send_requests = Histogram(
"synapse_threepid_send_requests_with_tries",
@@ -600,116 +460,6 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
)
-class ReactorLastSeenMetric:
- def collect(self) -> Iterable[Metric]:
- cm = GaugeMetricFamily(
- "python_twisted_reactor_last_seen",
- "Seconds since the Twisted reactor was last seen",
- )
- cm.add_metric([], time.time() - last_ticked)
- yield cm
-
-
-REGISTRY.register(ReactorLastSeenMetric())
-
-# The minimum time in seconds between GCs for each generation, regardless of the current GC
-# thresholds and counts.
-MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
-
-# The time (in seconds since the epoch) of the last time we did a GC for each generation.
-_last_gc = [0.0, 0.0, 0.0]
-
-
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
- @functools.wraps(func)
- def f(*args: Any, **kwargs: Any) -> Any:
- now = reactor.seconds()
- num_pending = 0
-
- # _newTimedCalls is one long list of *all* pending calls. Below loop
- # is based off of impl of reactor.runUntilCurrent
- for delayed_call in reactor._newTimedCalls:
- if delayed_call.time > now:
- break
-
- if delayed_call.delayed_time > 0:
- continue
-
- num_pending += 1
-
- num_pending += len(reactor.threadCallQueue)
- start = time.time()
- ret = func(*args, **kwargs)
- end = time.time()
-
- # record the amount of wallclock time spent running pending calls.
- # This is a proxy for the actual amount of time between reactor polls,
- # since about 25% of time is actually spent running things triggered by
- # I/O events, but that is harder to capture without rewriting half the
- # reactor.
- tick_time.observe(end - start)
- pending_calls_metric.observe(num_pending)
-
- # Update the time we last ticked, for the metric to test whether
- # Synapse's reactor has frozen
- global last_ticked
- last_ticked = end
-
- if running_on_pypy:
- return ret
-
- # Check if we need to do a manual GC (since it's been disabled), and do
- # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
- # promote an object into gen 2, and we don't want to handle the same
- # object multiple times.
- threshold = gc.get_threshold()
- counts = gc.get_count()
- for i in (2, 1, 0):
- # We check if we need to do one based on a straightforward
- # comparison between the threshold and count. We also do an extra
- # check to make sure that we don't do a GC too often.
- if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
- if i == 0:
- logger.debug("Collecting gc %d", i)
- else:
- logger.info("Collecting gc %d", i)
-
- start = time.time()
- unreachable = gc.collect(i)
- end = time.time()
-
- _last_gc[i] = end
-
- gc_time.labels(i).observe(end - start)
- gc_unreachable.labels(i).set(unreachable)
-
- return ret
-
- return cast(F, f)
-
-
-try:
- # Ensure the reactor has all the attributes we expect
- reactor.seconds # type: ignore
- reactor.runUntilCurrent # type: ignore
- reactor._newTimedCalls # type: ignore
- reactor.threadCallQueue # type: ignore
-
- # runUntilCurrent is called when we have pending calls. It is called once
- # per iteration after fd polling.
- reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
-
- # We manually run the GC each reactor tick so that we can get some metrics
- # about time spent doing GC,
- if not running_on_pypy:
- gc.disable()
-except AttributeError:
- pass
-
-
__all__ = [
"MetricsResource",
"generate_latest",
@@ -717,4 +467,6 @@ __all__ = [
"LaterGauge",
"InFlightGauge",
"GaugeBucketCollector",
+ "MIN_TIME_BETWEEN_GCS",
+ "install_gc_manager",
]
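
The `LaterGauge` hunk above is representative of the attrs conversions made throughout this diff: `attr.ib(type=...)` assignments become plain annotations once `auto_attribs=True` is set. A minimal sketch of the two equivalent spellings, using an illustrative class rather than one from the codebase:

import attr

# Old style: explicit attr.ib() calls carrying a `type` argument.
@attr.s(slots=True)
class PointOld:
    x = attr.ib(type=int)
    y = attr.ib(type=int, default=0)

# New style: auto_attribs=True turns annotated class attributes into fields.
@attr.s(slots=True, auto_attribs=True)
class PointNew:
    x: int
    y: int = 0

assert PointOld(1).y == PointNew(1).y == 0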
diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py
new file mode 100644
index 000000000..2bc909efa
--- /dev/null
+++ b/synapse/metrics/_gc.py
@@ -0,0 +1,203 @@
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import gc
+import logging
+import platform
+import time
+from typing import Iterable
+
+from prometheus_client.core import (
+ REGISTRY,
+ CounterMetricFamily,
+ Gauge,
+ GaugeMetricFamily,
+ Histogram,
+ Metric,
+)
+
+from twisted.internet import task
+
+"""Prometheus metrics for garbage collection"""
+
+
+logger = logging.getLogger(__name__)
+
+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
+
+running_on_pypy = platform.python_implementation() == "PyPy"
+
+#
+# Python GC metrics
+#
+
+gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
+gc_time = Histogram(
+ "python_gc_time",
+ "Time taken to GC (sec)",
+ ["gen"],
+ buckets=[
+ 0.0025,
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.10,
+ 0.25,
+ 0.50,
+ 1.00,
+ 2.50,
+ 5.00,
+ 7.50,
+ 15.00,
+ 30.00,
+ 45.00,
+ 60.00,
+ ],
+)
+
+
+class GCCounts:
+ def collect(self) -> Iterable[Metric]:
+ cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
+ for n, m in enumerate(gc.get_count()):
+ cm.add_metric([str(n)], m)
+
+ yield cm
+
+
+def install_gc_manager() -> None:
+ """Disable automatic GC, and replace it with a task that runs every 100ms
+
+ This means that (a) we can limit how often GC runs; (b) we can get some metrics
+ about GC activity.
+
+ It does nothing on PyPy.
+ """
+
+ if running_on_pypy:
+ return
+
+ REGISTRY.register(GCCounts())
+
+ gc.disable()
+
+ # The time (in seconds since the epoch) of the last time we did a GC for each generation.
+ _last_gc = [0.0, 0.0, 0.0]
+
+ def _maybe_gc() -> None:
+ # Check if we need to do a manual GC (since it's been disabled), and do
+ # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+ # promote an object into gen 2, and we don't want to handle the same
+ # object multiple times.
+ threshold = gc.get_threshold()
+ counts = gc.get_count()
+ end = time.time()
+ for i in (2, 1, 0):
+ # We check if we need to do one based on a straightforward
+ # comparison between the threshold and count. We also do an extra
+ # check to make sure that we don't do a GC too often.
+ if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
+ if i == 0:
+ logger.debug("Collecting gc %d", i)
+ else:
+ logger.info("Collecting gc %d", i)
+
+ start = time.time()
+ unreachable = gc.collect(i)
+ end = time.time()
+
+ _last_gc[i] = end
+
+ gc_time.labels(i).observe(end - start)
+ gc_unreachable.labels(i).set(unreachable)
+
+ gc_task = task.LoopingCall(_maybe_gc)
+ gc_task.start(0.1)
+
+
+#
+# PyPy GC / memory metrics
+#
+
+
+class PyPyGCStats:
+ def collect(self) -> Iterable[Metric]:
+
+ # @stats is a pretty-printer object with __str__() returning a nice table,
+ # plus some fields that contain data from that table.
+ # Unfortunately, fields are pretty-printed themselves (i.e. '4.5MB').
+ stats = gc.get_stats(memory_pressure=False) # type: ignore
+ # @s contains same fields as @stats, but as actual integers.
+ s = stats._s # type: ignore
+
+ # also note that field naming is completely braindead
+ # and only vaguely correlates with the pretty-printed table.
+ # >>>> gc.get_stats(False)
+ # Total memory consumed:
+ # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
+ # in arenas: 3.0MB # s.total_arena_memory
+ # rawmalloced: 1.7MB # s.total_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler used: 31.0kB # s.jit_backend_used
+ # -----------------------------
+ # Total: 8.8MB # stats.memory_used_sum
+ #
+ # Total memory allocated:
+ # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
+ # in arenas: 30.9MB # s.peak_arena_memory
+ # rawmalloced: 4.1MB # s.peak_rawmalloced_memory
+ # nursery: 4.0MB # s.nursery_size
+ # raw assembler allocated: 1.0MB # s.jit_backend_allocated
+ # -----------------------------
+ # Total: 39.7MB # stats.memory_allocated_sum
+ #
+ # Total time spent in GC: 0.073 # s.total_gc_time
+
+ pypy_gc_time = CounterMetricFamily(
+ "pypy_gc_time_seconds_total",
+ "Total time spent in PyPy GC",
+ labels=[],
+ )
+ pypy_gc_time.add_metric([], s.total_gc_time / 1000)
+ yield pypy_gc_time
+
+ pypy_mem = GaugeMetricFamily(
+ "pypy_memory_bytes",
+ "Memory tracked by PyPy allocator",
+ labels=["state", "class", "kind"],
+ )
+ # memory used by JIT assembler
+ pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
+ pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
+ # memory used by GCed objects
+ pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
+ pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
+ pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
+ pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
+ pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
+ pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
+ # totals
+ pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
+ pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
+ pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
+ yield pypy_mem
+
+
+if running_on_pypy:
+ REGISTRY.register(PyPyGCStats())
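
A minimal sketch of how the new module is wired up; the exact call site is assumed here (it is not part of this file), but `install_gc_manager` is re-exported from `synapse.metrics` per the `__all__` change above and must run once before the reactor starts:

from twisted.internet import reactor

from synapse.metrics import install_gc_manager

# On CPython this disables automatic GC and schedules _maybe_gc() on a
# 100ms LoopingCall, collecting generations 2, 1, 0 subject to
# MIN_TIME_BETWEEN_GCS; on PyPy it is a no-op.
install_gc_manager()
reactor.run()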
diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
new file mode 100644
index 000000000..f38f79831
--- /dev/null
+++ b/synapse/metrics/_reactor_metrics.py
@@ -0,0 +1,83 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import select
+import time
+from typing import Any, Iterable, List, Tuple
+
+from prometheus_client import Histogram, Metric
+from prometheus_client.core import REGISTRY, GaugeMetricFamily
+
+from twisted.internet import reactor
+
+#
+# Twisted reactor metrics
+#
+
+tick_time = Histogram(
+ "python_twisted_reactor_tick_time",
+ "Tick time of the Twisted reactor (sec)",
+ buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
+)
+
+
+class EpollWrapper:
+ """a wrapper for an epoll object which records the time between polls"""
+
+ def __init__(self, poller: "select.epoll"): # type: ignore[name-defined]
+ self.last_polled = time.time()
+ self._poller = poller
+
+ def poll(self, *args, **kwargs) -> List[Tuple[int, int]]: # type: ignore[no-untyped-def]
+ # record the time since poll() was last called. This gives a good proxy for
+ # how long it takes to run everything in the reactor - ie, how long anything
+ # waiting for the next tick will have to wait.
+ tick_time.observe(time.time() - self.last_polled)
+
+ ret = self._poller.poll(*args, **kwargs)
+
+ self.last_polled = time.time()
+ return ret
+
+ def __getattr__(self, item: str) -> Any:
+ return getattr(self._poller, item)
+
+
+class ReactorLastSeenMetric:
+ def __init__(self, epoll_wrapper: EpollWrapper):
+ self._epoll_wrapper = epoll_wrapper
+
+ def collect(self) -> Iterable[Metric]:
+ cm = GaugeMetricFamily(
+ "python_twisted_reactor_last_seen",
+ "Seconds since the Twisted reactor was last seen",
+ )
+ cm.add_metric([], time.time() - self._epoll_wrapper.last_polled)
+ yield cm
+
+
+try:
+ # if the reactor has a `_poller` attribute, which is an `epoll` object
+ # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will
+ # measure the time between ticks
+ from select import epoll # type: ignore[attr-defined]
+
+ poller = reactor._poller # type: ignore[attr-defined]
+except (AttributeError, ImportError):
+ pass
+else:
+ if isinstance(poller, epoll):
+ poller = EpollWrapper(poller)
+ reactor._poller = poller # type: ignore[attr-defined]
+ REGISTRY.register(ReactorLastSeenMetric(poller))
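
The wrapper leans on `__getattr__` delegation so the reactor still sees a fully featured poller and only `poll()` is intercepted. The same measure-between-calls pattern in isolation, with an illustrative `Ticker` proxy that is not part of the module:

import time
from typing import Any


class Ticker:
    """Delegating proxy that records the gap between successive poll() calls."""

    def __init__(self, inner: Any) -> None:
        self.last_polled = time.time()
        self._inner = inner

    def poll(self, *args: Any, **kwargs: Any) -> Any:
        # The gap since the last poll() approximates how long the event loop
        # spent running callbacks before coming back to wait for I/O.
        print(f"tick took {time.time() - self.last_polled:.6f}s")
        ret = self._inner.poll(*args, **kwargs)
        self.last_polled = time.time()
        return ret

    def __getattr__(self, item: str) -> Any:
        # Everything else (register(), close(), fileno(), ...) passes through.
        return getattr(self._inner, item)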
diff --git a/synapse/notifier.py b/synapse/notifier.py
index bbabdb058..632b2245e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -40,7 +40,6 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging import issue9533_logger
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import log_kv, start_active_span
-from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import (
@@ -193,15 +192,15 @@ class EventStreamResult:
return bool(self.events)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PendingRoomEventEntry:
- event_pos = attr.ib(type=PersistedEventPosition)
- extra_users = attr.ib(type=Collection[UserID])
+ event_pos: PersistedEventPosition
+ extra_users: Collection[UserID]
- room_id = attr.ib(type=str)
- type = attr.ib(type=str)
- state_key = attr.ib(type=Optional[str])
- membership = attr.ib(type=Optional[str])
+ room_id: str
+ type: str
+ state_key: Optional[str]
+ membership: Optional[str]
class Notifier:
@@ -686,7 +685,6 @@ class Notifier:
else:
return False
- @log_function
def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec()
expired_streams = []
@@ -700,7 +698,6 @@ class Notifier:
for expired_stream in expired_streams:
expired_stream.remove(self)
- @log_function
def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 820f6f3f7..5176a1c18 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -23,25 +23,25 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PusherConfig:
"""Parameters necessary to configure a pusher."""
- id = attr.ib(type=Optional[str])
- user_name = attr.ib(type=str)
- access_token = attr.ib(type=Optional[int])
- profile_tag = attr.ib(type=str)
- kind = attr.ib(type=str)
- app_id = attr.ib(type=str)
- app_display_name = attr.ib(type=str)
- device_display_name = attr.ib(type=str)
- pushkey = attr.ib(type=str)
- ts = attr.ib(type=int)
- lang = attr.ib(type=Optional[str])
- data = attr.ib(type=Optional[JsonDict])
- last_stream_ordering = attr.ib(type=int)
- last_success = attr.ib(type=Optional[int])
- failing_since = attr.ib(type=Optional[int])
+ id: Optional[str]
+ user_name: str
+ access_token: Optional[int]
+ profile_tag: str
+ kind: str
+ app_id: str
+ app_display_name: str
+ device_display_name: str
+ pushkey: str
+ ts: int
+ lang: Optional[str]
+ data: Optional[JsonDict]
+ last_stream_ordering: int
+ last_success: Optional[int]
+ failing_since: Optional[int]
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
@@ -57,12 +57,12 @@ class PusherConfig:
}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ThrottleParams:
"""Parameters for controlling the rate of sending pushes via email."""
- last_sent_ts = attr.ib(type=int)
- throttle_ms = attr.ib(type=int)
+ last_sent_ts: int
+ throttle_ms: int
class Pusher(metaclass=abc.ABCMeta):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 009d8e77b..bee660893 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -298,7 +298,7 @@ RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class RulesForRoomData:
"""The data stored in the cache by `RulesForRoom`.
@@ -307,29 +307,29 @@ class RulesForRoomData:
"""
# event_id -> (user_id, state)
- member_map = attr.ib(type=MemberMap, factory=dict)
+ member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
- rules_by_user = attr.ib(type=RulesByUser, factory=dict)
+ rules_by_user: RulesByUser = attr.Factory(dict)
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
- state_group = attr.ib(type=StateGroup, factory=object)
+ state_group: StateGroup = attr.Factory(object)
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
- sequence = attr.ib(type=int, default=0)
+ sequence: int = 0
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- uninteresting_user_set = attr.ib(type=Set[str], factory=set)
+ uninteresting_user_set: Set[str] = attr.Factory(set)
class RulesForRoom:
@@ -553,7 +553,7 @@ class RulesForRoom:
self.data.state_group = state_group
-@attr.attrs(slots=True, frozen=True)
+@attr.attrs(slots=True, frozen=True, auto_attribs=True)
class _Invalidation:
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# which means that it is stored on the bulk_get_push_rules cache entry. In order
@@ -564,8 +564,8 @@ class _Invalidation:
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
# set `frozen=True`.
- cache = attr.ib(type=LruCache)
- room_id = attr.ib(type=str)
+ cache: LruCache
+ room_id: str
def __call__(self) -> None:
rules_data = self.cache.get(self.room_id, None, update_metrics=False)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ff904c2b4..dadfc5741 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -178,7 +178,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.email_validation
- % {"server_name": self.hs.config.server.server_name},
+ % {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)
@@ -209,7 +209,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.email_validation
- % {"server_name": self.hs.config.server.server_name},
+ % {"server_name": self.hs.config.server.server_name, "app": self.app_name},
template_vars,
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a390cfcb7..4f4f1ad45 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -50,12 +50,12 @@ data part are:
"""
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow:
"""A parsed row from the events replication stream"""
- type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
- data = attr.ib() # BaseEventsStreamRow
+ type: str # the TypeId of one of the *EventsStreamRows
+ data: "BaseEventsStreamRow"
class BaseEventsStreamRow:
@@ -79,28 +79,28 @@ class BaseEventsStreamRow:
return cls(*data)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib(type=str)
- room_id = attr.ib(type=str)
- type = attr.ib(type=str)
- state_key = attr.ib(type=Optional[str])
- redacts = attr.ib(type=Optional[str])
- relates_to = attr.ib(type=Optional[str])
- membership = attr.ib(type=Optional[str])
- rejected = attr.ib(type=bool)
+ event_id: str
+ room_id: str
+ type: str
+ state_key: Optional[str]
+ redacts: Optional[str]
+ relates_to: Optional[str]
+ membership: Optional[str]
+ rejected: bool
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeId = "state"
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str
- event_id = attr.ib() # str, optional
+ room_id: str
+ type: str
+ state_key: str
+ event_id: Optional[str]
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 6ec00ce0b..e9bce22a3 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -123,34 +123,25 @@ class BackgroundUpdateStartJobRestServlet(RestServlet):
job_name = body["job_name"]
if job_name == "populate_stats_process_rooms":
- jobs = [
- {
- "update_name": "populate_stats_process_rooms",
- "progress_json": "{}",
- },
- ]
+ jobs = [("populate_stats_process_rooms", "{}", "")]
elif job_name == "regenerate_directory":
jobs = [
- {
- "update_name": "populate_user_directory_createtables",
- "progress_json": "{}",
- "depends_on": "",
- },
- {
- "update_name": "populate_user_directory_process_rooms",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_createtables",
- },
- {
- "update_name": "populate_user_directory_process_users",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_rooms",
- },
- {
- "update_name": "populate_user_directory_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_user_directory_process_users",
- },
+ ("populate_user_directory_createtables", "{}", ""),
+ (
+ "populate_user_directory_process_rooms",
+ "{}",
+ "populate_user_directory_createtables",
+ ),
+ (
+ "populate_user_directory_process_users",
+ "{}",
+ "populate_user_directory_process_rooms",
+ ),
+ (
+ "populate_user_directory_cleanup",
+ "{}",
+ "populate_user_directory_process_users",
+ ),
]
else:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
@@ -158,6 +149,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet):
try:
await self._store.db_pool.simple_insert_many(
table="background_updates",
+ keys=("update_name", "progress_json", "depends_on"),
values=jobs,
desc=f"admin_api_run_{job_name}",
)
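
For reference, the tuples above line up positionally with the new `keys` argument. A self-contained sketch of the statement this convention boils down to (roughly an executemany; the exact placeholder style depends on the database driver):

keys = ("update_name", "progress_json", "depends_on")
values = [("populate_stats_process_rooms", "{}", "")]
sql = "INSERT INTO background_updates (%s) VALUES (%s)" % (
    ", ".join(keys),
    ", ".join("?" for _ in keys),
)
# The statement is executed once per tuple in `values`:
print(sql)     # INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES (?, ?, ?)
print(values)  # [('populate_stats_process_rooms', '{}', '')]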
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 50d88c910..8cd3fa189 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -111,25 +111,37 @@ class DestinationsRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
+ if not await self._store.is_destination_known(destination):
+ raise NotFoundError("Unknown destination")
+
destination_retry_timings = await self._store.get_destination_retry_timings(
destination
)
- if not destination_retry_timings:
- raise NotFoundError("Unknown destination")
-
last_successful_stream_ordering = (
await self._store.get_destination_last_successful_stream_ordering(
destination
)
)
- response = {
+ response: JsonDict = {
"destination": destination,
- "failure_ts": destination_retry_timings.failure_ts,
- "retry_last_ts": destination_retry_timings.retry_last_ts,
- "retry_interval": destination_retry_timings.retry_interval,
"last_successful_stream_ordering": last_successful_stream_ordering,
}
+ if destination_retry_timings:
+ response = {
+ **response,
+ "failure_ts": destination_retry_timings.failure_ts,
+ "retry_last_ts": destination_retry_timings.retry_last_ts,
+ "retry_interval": destination_retry_timings.retry_interval,
+ }
+ else:
+ response = {
+ **response,
+ "failure_ts": None,
+ "retry_last_ts": 0,
+ "retry_interval": 0,
+ }
+
return HTTPStatus.OK, response
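
With this change, a destination Synapse knows about but has never had to retry now yields default values instead of a 404. A sketch of the two response shapes, with illustrative field values:

# Destination known, retry timings recorded:
{
    "destination": "remote.example.com",
    "last_successful_stream_ordering": 99,
    "failure_ts": 1641000000000,
    "retry_last_ts": 1641000600000,
    "retry_interval": 60000,
}

# Destination known, no retry timings recorded yet:
{
    "destination": "remote.example.com",
    "last_successful_stream_ordering": 99,
    "failure_ts": None,
    "retry_last_ts": 0,
    "retry_interval": 0,
}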
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 7236e4027..299f5c9eb 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -466,7 +466,7 @@ class UserMediaRestServlet(RestServlet):
)
deleted_media, total = await self.media_repository.delete_local_media_ids(
- ([row["media_id"] for row in media])
+ [row["media_id"] for row in media]
)
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 6030373eb..efe25fe7e 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -424,7 +424,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
- room_state = await self._event_serializer.serialize_events(events.values(), now)
+ room_state = self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@@ -744,22 +744,17 @@ class RoomEventContextServlet(RestServlet):
)
time_now = self.clock.time_msec()
- results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"],
- time_now,
- bundle_aggregations=True,
+ aggregations = results.pop("aggregations", None)
+ results["events_before"] = self._event_serializer.serialize_events(
+ results["events_before"], time_now, bundle_aggregations=aggregations
)
- results["event"] = await self._event_serializer.serialize_event(
- results["event"],
- time_now,
- bundle_aggregations=True,
+ results["event"] = self._event_serializer.serialize_event(
+ results["event"], time_now, bundle_aggregations=aggregations
)
- results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"],
- time_now,
- bundle_aggregations=True,
+ results["events_after"] = self._event_serializer.serialize_events(
+ results["events_after"], time_now, bundle_aggregations=aggregations
)
- results["state"] = await self._event_serializer.serialize_events(
+ results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 78e795c34..c2617ee30 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -173,12 +173,11 @@ class UserRestServletV2(RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
- ret = await self.admin_handler.get_user(target_user)
-
- if not ret:
+ user_info_dict = await self.admin_handler.get_user(target_user)
+ if not user_info_dict:
raise NotFoundError("User not found")
- return HTTPStatus.OK, ret
+ return HTTPStatus.OK, user_info_dict
async def on_PUT(
self, request: SynapseRequest, user_id: str
@@ -399,10 +398,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True
)
- user = await self.admin_handler.get_user(target_user)
- assert user is not None
+ user_info_dict = await self.admin_handler.get_user(target_user)
+ assert user_info_dict is not None
- return 201, user
+ return HTTPStatus.CREATED, user_info_dict
class UserRegisterServlet(RestServlet):
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 13b72a045..672c82106 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -91,7 +91,7 @@ class EventRestServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- result = await self._event_serializer.serialize_event(event, time_now)
+ result = self._event_serializer.serialize_event(event, time_now)
return 200, result
else:
return 404, "Event not found."
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index acd0c9e13..8e427a96a 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -72,7 +72,7 @@ class NotificationsServlet(RestServlet):
"actions": pa.actions,
"ts": pa.received_ts,
"event": (
- await self._event_serializer.serialize_event(
+ self._event_serializer.serialize_event(
notif_events[pa.event_id],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 5815650ee..8cf5ebaa0 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,28 +19,20 @@ any time to reflect changes in the MSC.
"""
import logging
-from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
-from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import (
- RestServlet,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
-)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
-from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client._base import client_patterns
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
-from synapse.util.stringutils import random_string
-
-from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -48,112 +40,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class RelationSendServlet(RestServlet):
- """Helper API for sending events that have relation data.
-
- Example API shape to send a 👍 reaction to a room:
-
- POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
- {}
-
- {
- "event_id": "$foobar"
- }
- """
-
- PATTERN = (
- "/rooms/(?P[^/]*)/send_relation"
- "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)"
- )
-
- def __init__(self, hs: "HomeServer"):
- super().__init__()
- self.auth = hs.get_auth()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.txns = HttpTransactionCache(hs)
-
- def register(self, http_server: HttpServer) -> None:
- http_server.register_paths(
- "POST",
- client_patterns(self.PATTERN + "$", releases=()),
- self.on_PUT_or_POST,
- self.__class__.__name__,
- )
- http_server.register_paths(
- "PUT",
- client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
- self.on_PUT,
- self.__class__.__name__,
- )
-
- def on_PUT(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: str,
- event_type: str,
- txn_id: Optional[str] = None,
- ) -> Awaitable[Tuple[int, JsonDict]]:
- return self.txns.fetch_or_execute_request(
- request,
- self.on_PUT_or_POST,
- request,
- room_id,
- parent_id,
- relation_type,
- event_type,
- txn_id,
- )
-
- async def on_PUT_or_POST(
- self,
- request: SynapseRequest,
- room_id: str,
- parent_id: str,
- relation_type: str,
- event_type: str,
- txn_id: Optional[str] = None,
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
- if event_type == EventTypes.Member:
- # Adding relations to a membership is meaningless, so we just deny it
- # at the CS API rather than trying to handle it correctly.
- raise SynapseError(400, "Cannot send member events with relations")
-
- content = parse_json_object_from_request(request)
-
- aggregation_key = parse_string(request, "key", encoding="utf-8")
-
- content["m.relates_to"] = {
- "event_id": parent_id,
- "rel_type": relation_type,
- }
- if aggregation_key is not None:
- content["m.relates_to"]["key"] = aggregation_key
-
- event_dict = {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- }
-
- try:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict=event_dict, txn_id=txn_id
- )
- event_id = event.event_id
- except ShadowBanError:
- event_id = "$" + random_string(43)
-
- return 200, {"event_id": event_id}
-
-
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
@@ -227,13 +113,16 @@ class RelationPaginationServlet(RestServlet):
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
- original_event = await self._event_serializer.serialize_event(
- event, now, bundle_aggregations=False
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
- serialized_events = await self._event_serializer.serialize_events(
- events, now, bundle_aggregations=True
+ aggregations = await self.store.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
)
return_value = pagination_chunk.to_dict()
@@ -422,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
)
now = self.clock.time_msec()
- serialized_events = await self._event_serializer.serialize_events(events, now)
+ serialized_events = self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = serialized_events
@@ -431,7 +320,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
RelationAggregationGroupPaginationServlet(hs).register(http_server)
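
The serialization pattern repeated across this diff: bundled aggregations are fetched from the store in one async call, then handed to the now-synchronous serializer. A condensed sketch, using the names from the hunks above and assuming the surrounding servlet context:

# Fetch aggregations for all events up front (async, one store call) ...
aggregations = await self.store.get_bundled_aggregations(
    events, requester.user.to_string()
)
# ... then serialize synchronously, passing the prefetched map in place of
# the old bundle_aggregations=True flag.
serialized_events = self._event_serializer.serialize_events(
    events, self.clock.time_msec(), bundle_aggregations=aggregations
)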
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 40330749e..90bb9142a 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -642,6 +642,7 @@ class RoomEventServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
+ self._store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@@ -660,10 +661,15 @@ class RoomEventServlet(RestServlet):
# https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
- time_now = self.clock.time_msec()
if event:
- event_dict = await self._event_serializer.serialize_event(
- event, time_now, bundle_aggregations=True
+ # Ensure there are bundled aggregations available.
+ aggregations = await self._store.get_bundled_aggregations(
+ [event], requester.user.to_string()
+ )
+
+ time_now = self.clock.time_msec()
+ event_dict = self._event_serializer.serialize_event(
+ event, time_now, bundle_aggregations=aggregations
)
return 200, event_dict
@@ -708,16 +714,17 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=True
+ aggregations = results.pop("aggregations", None)
+ results["events_before"] = self._event_serializer.serialize_events(
+ results["events_before"], time_now, bundle_aggregations=aggregations
)
- results["event"] = await self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=True
+ results["event"] = self._event_serializer.serialize_event(
+ results["event"], time_now, bundle_aggregations=aggregations
)
- results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=True
+ results["events_after"] = self._event_serializer.serialize_events(
+ results["events_after"], time_now, bundle_aggregations=aggregations
)
- results["state"] = await self._event_serializer.serialize_events(
+ results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e99a943d0..d20ae1421 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -17,7 +17,6 @@ from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
- Awaitable,
Callable,
Dict,
Iterable,
@@ -395,7 +394,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = await self._event_serializer.serialize_event(
+ invite = self._event_serializer.serialize_event(
room.invite,
time_now,
token_id=token_id,
@@ -432,7 +431,7 @@ class SyncRestServlet(RestServlet):
"""
knocked = {}
for room in rooms:
- knock = await self._event_serializer.serialize_event(
+ knock = self._event_serializer.serialize_event(
room.knock,
time_now,
token_id=token_id,
@@ -525,21 +524,14 @@ class SyncRestServlet(RestServlet):
The room, encoded in our response format
"""
- def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
+ def serialize(
+ events: Iterable[EventBase],
+ aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+ ) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
time_now=time_now,
- # Don't bother to bundle aggregations if the timeline is unlimited,
- # as clients will have all the necessary information.
- # bundle_aggregations=room.timeline.limited,
- #
- # richvdh 2021-12-15: disable this temporarily as it has too high an
- # overhead for initialsyncs. We need to figure out a way that the
- # bundling can be done *before* the events are stored in the
- # SyncResponseCache so that this part can be synchronous.
- #
- # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations.
- bundle_aggregations=False,
+ bundle_aggregations=aggregations,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
@@ -561,8 +553,10 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
- serialized_state = await serialize(state_events)
- serialized_timeline = await serialize(timeline_events)
+ serialized_state = serialize(state_events)
+ serialized_timeline = serialize(
+ timeline_events, room.timeline.bundled_aggregations
+ )
account_data = room.account_data
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index fca239d8c..9f6c251ca 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -343,7 +343,7 @@ class SpamMediaException(NotFoundError):
"""
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class ReadableFileWrapper:
"""Wrapper that allows reading a file in chunks, yielding to the reactor,
and writing to a callback.
@@ -354,8 +354,8 @@ class ReadableFileWrapper:
CHUNK_SIZE = 2 ** 14
- clock = attr.ib(type=Clock)
- path = attr.ib(type=str)
+ clock: Clock
+ path: str
async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index cce1527ed..2177b46c9 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
class OEmbedResult:
# The Open Graph result (converted from the oEmbed result).
open_graph_result: JsonDict
+ # The author_name of the oEmbed result
+ author_name: Optional[str]
# Number of milliseconds to cache the content, according to the oEmbed response.
#
# This will be None if no cache-age is provided in the oEmbed response (or
@@ -154,11 +156,12 @@ class OEmbedProvider:
"og:url": url,
}
- # Use either title or author's name as the title.
- title = oembed.get("title") or oembed.get("author_name")
+ title = oembed.get("title")
if title:
open_graph_response["og:title"] = title
+ author_name = oembed.get("author_name")
+
# Use the provider name as the site name.
provider_name = oembed.get("provider_name")
if provider_name:
@@ -193,9 +196,10 @@ class OEmbedProvider:
# Trap any exception and let the code follow as usual.
logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
open_graph_response = {}
+ author_name = None
cache_age = None
- return OEmbedResult(open_graph_response, cache_age)
+ return OEmbedResult(open_graph_response, author_name, cache_age)
def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a3829d943..e8881bc87 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -262,6 +262,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# The number of milliseconds that the response should be considered valid.
expiration_ms = media_info.expires
+ author_name: Optional[str] = None
if _is_media(media_info.media_type):
file_id = media_info.filesystem_id
@@ -294,17 +295,25 @@ class PreviewUrlResource(DirectServeJsonResource):
# Check if this HTML document points to oEmbed information and
# defer to that.
oembed_url = self._oembed.autodiscover_from_html(tree)
- og = {}
+ og_from_oembed: JsonDict = {}
if oembed_url:
oembed_info = await self._download_url(oembed_url, user)
- og, expiration_ms = await self._handle_oembed_response(
+ (
+ og_from_oembed,
+ author_name,
+ expiration_ms,
+ ) = await self._handle_oembed_response(
url, oembed_info, expiration_ms
)
- # If there was no oEmbed URL (or oEmbed parsing failed), attempt
- # to generate the Open Graph information from the HTML.
- if not oembed_url or not og:
- og = parse_html_to_open_graph(tree, media_info.uri)
+ # Parse Open Graph information from the HTML in case the oEmbed
+ # response failed or is incomplete.
+ og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+
+ # Compile the Open Graph response by using the scraped
+ # information from the HTML and overlaying any information
+ # from the oEmbed response.
+ og = {**og_from_html, **og_from_oembed}
await self._precache_image_url(user, media_info, og)
else:
@@ -312,7 +321,7 @@ class PreviewUrlResource(DirectServeJsonResource):
elif oembed_url:
# Handle the oEmbed information.
- og, expiration_ms = await self._handle_oembed_response(
+ og, author_name, expiration_ms = await self._handle_oembed_response(
url, media_info, expiration_ms
)
await self._precache_image_url(user, media_info, og)
@@ -321,6 +330,11 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Failed to find any OG data in %s", url)
og = {}
+ # If we don't have a title but we have author_name, copy it as the title.
+ if not og.get("og:title") and author_name:
+ og["og:title"] = author_name
+
# filter out any stupidly long values
keys_to_remove = []
for k, v in og.items():
@@ -484,7 +498,7 @@ class PreviewUrlResource(DirectServeJsonResource):
async def _handle_oembed_response(
self, url: str, media_info: MediaInfo, expiration_ms: int
- ) -> Tuple[JsonDict, int]:
+ ) -> Tuple[JsonDict, Optional[str], int]:
"""
Parse the downloaded oEmbed info.
@@ -497,11 +511,12 @@ class PreviewUrlResource(DirectServeJsonResource):
Returns:
A tuple of:
The Open Graph dictionary, if the oEmbed info can be parsed.
+ The author name if it could be retrieved from oEmbed.
The (possibly updated) length of time, in milliseconds, the media is valid for.
"""
# If JSON was not returned, there's nothing to do.
if not _is_json(media_info.media_type):
- return {}, expiration_ms
+ return {}, None, expiration_ms
with open(media_info.filename, "rb") as file:
body = file.read()
@@ -513,7 +528,7 @@ class PreviewUrlResource(DirectServeJsonResource):
if open_graph_result and oembed_response.cache_age is not None:
expiration_ms = oembed_response.cache_age
- return open_graph_result, expiration_ms
+ return open_graph_result, oembed_response.author_name, expiration_ms
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
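
A minimal sketch of the new Open Graph composition, with illustrative values: HTML-scraped data forms the base, the oEmbed response overlays it, and `author_name` only fills in a missing title:

og_from_html = {"og:title": "", "og:description": "From <meta> tags"}
og_from_oembed = {"og:description": "From the oEmbed endpoint"}
author_name = "Example Author"

# oEmbed keys win on conflict; HTML fills in whatever oEmbed lacked.
og = {**og_from_html, **og_from_oembed}

# author_name is only a fallback for a missing og:title.
if not og.get("og:title") and author_name:
    og["og:title"] = author_name

print(og)
# {'og:title': 'Example Author', 'og:description': 'From the oEmbed endpoint'}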
diff --git a/synapse/server.py b/synapse/server.py
index 185e40e4d..3032f0b73 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -759,7 +759,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
- return EventClientSerializer(self)
+ return EventClientSerializer()
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 69ac8c342..67e8bc6ec 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -45,7 +45,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import ContextResourceUsage
-from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
@@ -450,19 +449,19 @@ class StateHandler:
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _StateResMetrics:
"""Keeps track of some usage metrics about state res."""
# System and User CPU time, in seconds
- cpu_time = attr.ib(type=float, default=0.0)
+ cpu_time: float = 0.0
# time spent on database transactions (excluding scheduling time). This roughly
# corresponds to the amount of work done on the db server, excluding event fetches.
- db_time = attr.ib(type=float, default=0.0)
+ db_time: float = 0.0
# number of events fetched from the db.
- db_events = attr.ib(type=int, default=0)
+ db_events: int = 0
_biggest_room_by_cpu_counter = Counter(
@@ -512,7 +511,6 @@ class StateResolutionHandler:
self.clock.looping_call(self._report_metrics, 120 * 1000)
- @log_function
async def resolve_state_groups(
self,
room_id: str,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 2cacc7dd6..57cc1d76e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -143,7 +143,7 @@ def make_conn(
return db_conn
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class LoggingDatabaseConnection:
"""A wrapper around a database connection that returns `LoggingTransaction`
as its cursor class.
@@ -151,9 +151,9 @@ class LoggingDatabaseConnection:
This is mainly used on startup to ensure that queries get logged correctly
"""
- conn = attr.ib(type=Connection)
- engine = attr.ib(type=BaseDatabaseEngine)
- default_txn_name = attr.ib(type=str)
+ conn: Connection
+ engine: BaseDatabaseEngine
+ default_txn_name: str
def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
@@ -934,56 +934,6 @@ class DatabasePool:
txn.execute(sql, vals)
async def simple_insert_many(
- self, table: str, values: List[Dict[str, Any]], desc: str
- ) -> None:
- """Executes an INSERT query on the named table.
-
- The input is given as a list of dicts, with one dict per row.
- Generally simple_insert_many_values should be preferred for new code.
-
- Args:
- table: string giving the table name
- values: dict of new column names and values for them
- desc: description of the transaction, for logging and metrics
- """
- await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
-
- @staticmethod
- def simple_insert_many_txn(
- txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
- ) -> None:
- """Executes an INSERT query on the named table.
-
- The input is given as a list of dicts, with one dict per row.
- Generally simple_insert_many_values_txn should be preferred for new code.
-
- Args:
- txn: The transaction to use.
- table: string giving the table name
- values: dict of new column names and values for them
- """
- if not values:
- return
-
- # This is a *slight* abomination to get a list of tuples of key names
- # and a list of tuples of value names.
- #
- # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
- # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
- #
- # The sort is to ensure that we don't rely on dictionary iteration
- # order.
- keys, vals = zip(
- *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
- )
-
- for k in keys:
- if k != keys[0]:
- raise RuntimeError("All items must have the same keys")
-
- return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
-
- async def simple_insert_many_values(
self,
table: str,
keys: Collection[str],
@@ -1002,11 +952,11 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_insert_many_values_txn, table, keys, values
+ desc, self.simple_insert_many_txn, table, keys, values
)
@staticmethod
- def simple_insert_many_values_txn(
+ def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 32a553fdd..ef475e18c 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -450,7 +450,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
- """Add some account_data to a room for a user.
+ """Add some global account_data for a user.
Args:
user_id: The user to add a tag for.
@@ -536,9 +536,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="ignored_users",
+ keys=("ignorer_user_id", "ignored_user_id"),
values=[
- {"ignorer_user_id": user_id, "ignored_user_id": u}
- for u in currently_ignored_users - previously_ignored_users
+ (user_id, u) for u in currently_ignored_users - previously_ignored_users
],
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3682cb6a8..4eca97189 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -432,14 +432,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_federation_outbox",
+ keys=(
+ "destination",
+ "stream_id",
+ "queued_ts",
+ "messages_json",
+ "instance_name",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": stream_id,
- "queued_ts": now_ms,
- "messages_json": json_encoder.encode(edu),
- "instance_name": self._instance_name,
- }
+ (
+ destination,
+ stream_id,
+ now_ms,
+ json_encoder.encode(edu),
+ self._instance_name,
+ )
for destination, edu in remote_messages_by_destination.items()
],
)
@@ -571,14 +578,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_inbox",
+ keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "message_json": message_json,
- "instance_name": self._instance_name,
- }
+ (user_id, device_id, stream_id, message_json, self._instance_name)
for user_id, messages_by_device in local_by_user_then_device.items()
for device_id, message_json in messages_by_device.items()
],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index bc7e87604..b2a5cd9a6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -53,6 +53,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+issue_8631_logger = logging.getLogger("synapse.8631_debug")
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
@@ -229,6 +230,12 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ data = {(user, device): stream_id for user, device, stream_id, _ in updates}
+ issue_8631_logger.debug(
+ "device updates need to be sent to %s: %s", destination, data
+ )
+
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = {r[0] for r in updates}
@@ -365,6 +372,17 @@ class DeviceWorkerStore(SQLBaseStore):
# and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result))
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ for (user_id, edu) in results:
+ issue_8631_logger.debug(
+ "device update to %s for %s from %s to %s: %s",
+ destination,
+ user_id,
+ from_stream_id,
+ last_processed_stream_id,
+ edu,
+ )
+
return last_processed_stream_id, results
def _get_device_updates_by_remote_txn(
@@ -781,7 +799,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
- ) -> Optional[Any]:
+ ) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
@@ -797,7 +815,9 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+ async def get_device_list_last_stream_id_for_remotes(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@@ -1384,6 +1404,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
content: JsonDict,
stream_id: str,
) -> None:
+ """Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
@@ -1443,6 +1464,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
+ """Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
@@ -1450,12 +1472,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
+ keys=("user_id", "device_id", "content"),
values=[
- {
- "user_id": user_id,
- "device_id": content["device_id"],
- "content": json_encoder.encode(content),
- }
+ (user_id, content["device_id"], json_encoder.encode(content))
for content in devices
],
)
@@ -1543,8 +1562,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
+ keys=("stream_id", "user_id", "device_id"),
values=[
- {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
+ (stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@@ -1571,18 +1591,27 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
+ keys=(
+ "destination",
+ "stream_id",
+ "user_id",
+ "device_id",
+ "sent",
+ "ts",
+ "opentracing_context",
+ ),
values=[
- {
- "destination": destination,
- "stream_id": next(next_stream_id),
- "user_id": user_id,
- "device_id": device_id,
- "sent": False,
- "ts": now,
- "opentracing_context": json_encoder.encode(context)
+ (
+ destination,
+ next(next_stream_id),
+ user_id,
+ device_id,
+ False,
+ now,
+ json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
- }
+ )
for destination in hosts
for device_id in device_ids
],
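# Illustrative sketch (a standalone model of the pattern added above, not part
# of the patch). `synapse.8631_debug` is a dedicated, normally-silent logger
# that an operator enables explicitly in the logging config; gating on
# isEnabledFor() keeps the dict comprehension off the hot path otherwise:
import logging
from typing import List, Tuple

issue_8631_logger = logging.getLogger("synapse.8631_debug")

def log_pending_device_updates(
    destination: str, updates: List[Tuple[str, str, int, str]]
) -> None:
    if issue_8631_logger.isEnabledFor(logging.DEBUG):
        # Only built when DEBUG is enabled for this specific logger.
        data = {(user, device): stream_id for user, device, stream_id, _ in updates}
        issue_8631_logger.debug(
            "device updates need to be sent to %s: %s", destination, data
        )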
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index f76c6121e..5903fdaf0 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -112,10 +112,8 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="room_alias_servers",
- values=[
- {"room_alias": room_alias.to_string(), "server": server}
- for server in servers
- ],
+ keys=("room_alias", "server"),
+ values=[(room_alias.to_string(), server) for server in servers],
)
self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 0cb48b9dd..b789a588a 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -110,16 +110,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
- {
- "user_id": user_id,
- "version": version_int,
- "room_id": room_id,
- "session_id": session_id,
- "first_message_index": room_key["first_message_index"],
- "forwarded_count": room_key["forwarded_count"],
- "is_verified": room_key["is_verified"],
- "session_data": json_encoder.encode(room_key["session_data"]),
- }
+ (
+ user_id,
+ version_int,
+ room_id,
+ session_id,
+ room_key["first_message_index"],
+ room_key["forwarded_count"],
+ room_key["is_verified"],
+ json_encoder.encode(room_key["session_data"]),
+ )
)
log_kv(
{
@@ -131,7 +131,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
await self.db_pool.simple_insert_many(
- table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
+ table="e2e_room_keys",
+ keys=(
+ "user_id",
+ "version",
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ values=values,
+ desc="add_e2e_room_keys",
)
@trace
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 57b5ffbad..1f8447b50 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -50,16 +50,16 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
- display_name = attr.ib(type=Optional[str])
+ display_name: Optional[str]
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
- keys = attr.ib(type=Optional[JsonDict])
+ keys: Optional[JsonDict]
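# Illustrative aside (hypothetical class and field names; the shapes mirror the
# conversion above). `auto_attribs=True` makes attrs read the field list from
# PEP 526 annotations, so `x = attr.ib(type=T)` becomes `x: T`, and a factory
# default becomes `attr.Factory(...)`. These two classes define identical
# attrs types:
from typing import Dict, Optional

import attr

@attr.s(slots=True)
class OldStyle:
    display_name = attr.ib(type=Optional[str])
    extras = attr.ib(type=Dict[str, str], factory=dict)

@attr.s(slots=True, auto_attribs=True)
class NewStyle:
    display_name: Optional[str]
    extras: Dict[str, str] = attr.Factory(dict)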
class EndToEndKeyBackgroundStore(SQLBaseStore):
@@ -387,15 +387,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- "key_id": key_id,
- "ts_added_ms": time_now,
- "key_json": json_bytes,
- }
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
for algorithm, key_id, json_bytes in new_keys
],
)
@@ -1186,15 +1187,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
- [
- {
- "user_id": user_id,
- "key_id": item.signing_key_id,
- "target_user_id": item.target_user_id,
- "target_device_id": item.target_device_id,
- "signature": item.signature,
- }
+ keys=(
+ "user_id",
+ "key_id",
+ "target_user_id",
+ "target_device_id",
+ "signature",
+ ),
+ values=[
+ (
+ user_id,
+ item.signing_key_id,
+ item.target_user_id,
+ item.target_device_id,
+ item.signature,
+ )
for item in signatures
],
- "add_e2e_signing_key",
+ desc="add_e2e_signing_key",
)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index a98e6b259..b7c4c6222 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -875,14 +875,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="event_push_summary",
+ keys=(
+ "user_id",
+ "room_id",
+ "notif_count",
+ "unread_count",
+ "stream_ordering",
+ ),
values=[
- {
- "user_id": user_id,
- "room_id": room_id,
- "notif_count": summary.notif_count,
- "unread_count": summary.unread_count,
- "stream_ordering": summary.stream_ordering,
- }
+ (
+ user_id,
+ room_id,
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ )
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is None
],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index dd255aefb..1ae1ebe10 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -39,7 +39,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -69,7 +68,7 @@ event_counter = Counter(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -80,9 +79,9 @@ class DeltaState:
should e.g. be removed from `current_state_events` table.
"""
- to_delete = attr.ib(type=List[Tuple[str, str]])
- to_insert = attr.ib(type=StateMap[str])
- no_longer_in_room = attr.ib(type=bool, default=False)
+ to_delete: List[Tuple[str, str]]
+ to_insert: StateMap[str]
+ no_longer_in_room: bool = False
class PersistEventsStore:
@@ -328,7 +327,6 @@ class PersistEventsStore:
return existing_prevs
- @log_function
def _persist_events_txn(
self,
txn: LoggingTransaction,
@@ -442,12 +440,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- }
+ (event.event_id, event.room_id, auth_id)
for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
@@ -675,8 +670,9 @@ class PersistEventsStore:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
values=[
- {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ (event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
@@ -782,13 +778,14 @@ class PersistEventsStore:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
+ keys=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
values=[
- {
- "origin_chain_id": source_id,
- "origin_sequence_number": source_seq,
- "target_chain_id": target_id,
- "target_sequence_number": target_seq,
- }
+ (source_id, source_seq, target_id, target_seq)
for (
source_id,
source_seq,
@@ -943,20 +940,28 @@ class PersistEventsStore:
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
to_insert.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "user_id": event.sender,
- "token_id": token_id,
- "txn_id": txn_id,
- "inserted_ts": self._clock.time_msec(),
- }
+ (
+ event.event_id,
+ event.room_id,
+ event.sender,
+ token_id,
+ txn_id,
+ self._clock.time_msec(),
+ )
)
if to_insert:
self.db_pool.simple_insert_many_txn(
txn,
table="event_txn_id",
+ keys=(
+ "event_id",
+ "room_id",
+ "user_id",
+ "token_id",
+ "txn_id",
+ "inserted_ts",
+ ),
values=to_insert,
)
@@ -1161,8 +1166,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
+ keys=("event_id", "room_id"),
values=[
- {"event_id": ev_id, "room_id": room_id}
+ (ev_id, room_id)
for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
@@ -1174,12 +1180,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
+ keys=("room_id", "event_id", "stream_ordering"),
values=[
- {
- "room_id": room_id,
- "event_id": event_id,
- "stream_ordering": max_stream_order,
- }
+ (room_id, event_id, max_stream_order)
for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
@@ -1251,20 +1254,22 @@ class PersistEventsStore:
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
- def _update_outliers_txn(self, txn, events_and_contexts):
+ def _update_outliers_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Update any outliers with new event info.
- This turns outliers into ex-outliers (unless the new event was
- rejected).
+ This turns outliers into ex-outliers (unless the new event was rejected), and
+ also removes any other events we have already seen from the list.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without events which
- are already in the events table.
+ new list, without events which are already in the events table.
"""
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1272,7 +1277,9 @@ class PersistEventsStore:
[event.event_id for event, _ in events_and_contexts],
)
- have_persisted = {event_id: outlier for event_id, outlier in txn}
+ have_persisted: Dict[str, bool] = {
+ event_id: outlier for event_id, outlier in txn
+ }
to_remove = set()
for event, context in events_and_contexts:
@@ -1282,15 +1289,22 @@ class PersistEventsStore:
to_remove.add(event)
if context.rejected:
- # If the event is rejected then we don't care if the event
- # was an outlier or not.
+ # If the incoming event is rejected then we don't care if the event
+ # was an outlier or not - what we have is at least as good.
continue
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
- # an outlier in the database. We now have some state at that
+ # an outlier in the database. We now have some state at that event
# so we need to update the state_groups table with that state.
+ #
+ # Note that we do not update the stream_ordering of the event in this
+ # scenario. XXX: does this cause bugs? It will mean we won't send such
+ # events down /sync. In general they will be historical events, so that
+ # doesn't matter too much, but that is not always the case.
+
+ logger.info("Updating state for ex-outlier event %s", event.event_id)
# insert into event_to_state_groups.
try:
@@ -1342,7 +1356,7 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
@@ -1358,7 +1372,7 @@ class PersistEventsStore:
),
)
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="events",
keys=(
@@ -1412,7 +1426,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, [False] + args)
- self.db_pool.simple_insert_many_values_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_events",
keys=("event_id", "room_id", "type", "state_key"),
@@ -1622,14 +1636,9 @@ class PersistEventsStore:
return self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": room_id,
- "topological_ordering": topological_ordering,
- }
- for label in labels
+ (event_id, label, room_id, topological_ordering) for label in labels
],
)
@@ -1657,16 +1666,13 @@ class PersistEventsStore:
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append(
- {
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": memoryview(ref_hash_bytes),
- }
- )
+ vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
self.db_pool.simple_insert_many_txn(
- txn, table="event_reference_hashes", values=vals
+ txn,
+ table="event_reference_hashes",
+ keys=("event_id", "algorithm", "hash"),
+ values=vals,
)
def _store_room_members_txn(
@@ -1689,18 +1695,25 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
+ keys=(
+ "event_id",
+ "user_id",
+ "sender",
+ "room_id",
+ "membership",
+ "display_name",
+ "avatar_url",
+ ),
values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": non_null_str_or_none(
- event.content.get("displayname")
- ),
- "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
- }
+ (
+ event.event_id,
+ event.state_key,
+ event.user_id,
+ event.room_id,
+ event.membership,
+ non_null_str_or_none(event.content.get("displayname")),
+ non_null_str_or_none(event.content.get("avatar_url")),
+ )
for event in events
],
)
@@ -1791,6 +1804,13 @@ class PersistEventsStore:
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
+ # It should be safe to only invalidate the cache if the user has not
+ # previously participated in the thread, but that's difficult (and
+ # potentially error-prone) so it is always invalidated.
+ txn.call_after(
+ self.store.get_thread_participated.invalidate,
+ (parent_id, event.room_id, event.sender),
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
@@ -2163,13 +2183,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
+ keys=("event_id", "prev_event_id", "room_id", "is_state"),
values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
+ (ev.event_id, e_id, ev.room_id, False)
for ev in events
for e_id in ev.prev_event_ids()
],
@@ -2226,17 +2242,17 @@ class PersistEventsStore:
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _LinkMap:
"""A helper type for tracking links between chains."""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
- maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+ maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)
# Stores the links that have been added (with new set to true), as tuples of
# `(source chain ID, source sequence no, target chain ID, target sequence no.)`
- additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+ additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)
def add_link(
self,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index a68f14ba4..d5f005966 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -65,22 +65,22 @@ class _BackgroundUpdates:
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
# The last room_id/depth/stream processed.
- room_id = attr.ib(type=str)
- depth = attr.ib(type=int)
- stream = attr.ib(type=int)
+ room_id: str
+ depth: int
+ stream: int
# Number of rows processed
- processed_count = attr.ib(type=int)
+ processed_count: int
# Map from room_id to last depth/stream processed for each room that we have
# processed all events for (i.e. the rooms we can flip the
# `has_auth_chain_index` for)
- finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+ finished_room_map: Dict[str, Tuple[int, int]]
class EventsBackgroundUpdatesStore(SQLBaseStore):
@@ -684,13 +684,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
+ keys=("event_id", "label", "room_id", "topological_ordering"),
values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": event_json["room_id"],
- "topological_ordering": event_json["depth"],
- }
+ (
+ event_id,
+ label,
+ event_json["room_id"],
+ event_json["depth"],
+ )
for label in event_json["content"].get(
EventContentFields.LABELS, []
)
@@ -803,29 +804,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not has_state:
state_events.append(
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
+ (event.event_id, event.room_id, event.type, event.state_key)
)
if not has_event_auth:
# Old, dodgy, events may have duplicate auth events, which we
# need to deduplicate as we have a unique constraint.
for auth_id in set(event.auth_event_ids()):
- auth_events.append(
- {
- "room_id": event.room_id,
- "event_id": event.event_id,
- "auth_id": auth_id,
- }
- )
+ auth_events.append((event.event_id, event.room_id, auth_id))
if state_events:
await self.db_pool.simple_insert_many(
table="state_events",
+ keys=("event_id", "room_id", "type", "state_key"),
values=state_events,
desc="_rejected_events_metadata_state_events",
)
@@ -833,6 +824,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if auth_events:
await self.db_pool.simple_insert_many(
table="event_auth",
+ keys=("event_id", "room_id", "auth_id"),
values=auth_events,
desc="_rejected_events_metadata_event_auth",
)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cbf9ec38f..4f05811a7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -129,18 +129,29 @@ class PresenceStore(PresenceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="presence_stream",
+ keys=(
+ "stream_id",
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ "instance_name",
+ ),
values=[
- {
- "stream_id": stream_id,
- "user_id": state.user_id,
- "state": state.state,
- "last_active_ts": state.last_active_ts,
- "last_federation_update_ts": state.last_federation_update_ts,
- "last_user_sync_ts": state.last_user_sync_ts,
- "status_msg": state.status_msg,
- "currently_active": state.currently_active,
- "instance_name": self._instance_name,
- }
+ (
+ stream_id,
+ state.user_id,
+ state.state,
+ state.last_active_ts,
+ state.last_federation_update_ts,
+ state.last_user_sync_ts,
+ state.status_msg,
+ state.currently_active,
+ self._instance_name,
+ )
for stream_id, state in zip(stream_orderings, presence_states)
],
)
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 747b4f31d..cf64cd63a 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -561,13 +561,9 @@ class PusherStore(PusherWorkerStore):
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
+ keys=("stream_id", "app_id", "pushkey", "user_id"),
values=[
- {
- "stream_id": stream_id,
- "app_id": pusher.app_id,
- "pushkey": pusher.pushkey,
- "user_id": user_id,
- }
+ (stream_id, pusher.app_id, pusher.pushkey, user_id)
for stream_id, pusher in zip(stream_ids, pushers)
],
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4175c82a2..aac94fa46 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception):
pass
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -69,14 +69,14 @@ class TokenLookupResult:
cached.
"""
- user_id = attr.ib(type=str)
- is_guest = attr.ib(type=bool, default=False)
- shadow_banned = attr.ib(type=bool, default=False)
- token_id = attr.ib(type=Optional[int], default=None)
- device_id = attr.ib(type=Optional[str], default=None)
- valid_until_ms = attr.ib(type=Optional[int], default=None)
- token_owner = attr.ib(type=str)
- token_used = attr.ib(type=bool, default=False)
+ user_id: str
+ is_guest: bool = False
+ shadow_banned: bool = False
+ token_id: Optional[int] = None
+ device_id: Optional[str] = None
+ valid_until_ms: Optional[int] = None
+ token_owner: str = attr.ib()
+ token_used: bool = False
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed25..2cb5d06c1 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
+from frozendict import frozendict
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
)
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+ self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
@cached(tree=True)
async def get_relations_for_event(
self,
@@ -354,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
- """Get the number of threaded replies, the senders of those replies, and
- the latest reply (if any) for the given event.
+ """Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
@@ -368,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
- # Fetch the count of threaded events and the latest event ID.
+ # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@@ -389,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
+ # Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
@@ -413,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ @cached()
+ async def get_thread_participated(
+ self, event_id: str, room_id: str, user_id: str
+ ) -> bool:
+ """Get whether the requesting user participated in a thread.
+
+ This is separate from get_thread_summary since that can be cached across
+ all users while this value is specific to the requester.
+
+ Args:
+ event_id: The thread related to this event ID.
+ room_id: The room the event belongs to.
+ user_id: The user requesting the summary.
+
+ Returns:
+ True if the requesting user participated in the thread, otherwise false.
+ """
+
+ def _get_thread_participated_txn(txn: LoggingTransaction) -> bool:
+ # Fetch whether the requester has participated or not.
+ sql = """
+ SELECT 1
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND room_id = ?
+ AND relation_type = ?
+ AND sender = ?
+ """
+
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+ return bool(txn.fetchone())
+
+ return await self.db_pool.runInteraction(
+ "get_thread_summary", _get_thread_summary_txn
+ )
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -515,6 +583,104 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+ # State events and redacted events do not get bundled aggregations.
+ if event.is_state() or event.internal_metadata.is_redacted():
+ return None
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the value to return directly,
+ # while others need more processing during serialization.
+ aggregations: Dict[str, Any] = {}
+
+ annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+ if annotations.chunk:
+ aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ references = await self.get_relations_for_event(
+ event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = await self.get_applicable_edit(event_id, room_id)
+
+ if edit:
+ aggregations[RelationTypes.REPLACE] = edit
+
+ # If this event is the start of a thread, include a summary of the replies.
+ if self._msc3440_enabled:
+ thread_count, latest_thread_event = await self.get_thread_summary(
+ event_id, room_id
+ )
+ participated = await self.get_thread_participated(
+ event_id, room_id, user_id
+ )
+ if latest_thread_event:
+ aggregations[RelationTypes.THREAD] = {
+ "latest_event": latest_thread_event,
+ "count": thread_count,
+ "current_user_participated": participated,
+ }
+
+ # Return the bundled aggregations so the caller can attach them to the event.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self,
+ events: Iterable[EventBase],
+ user_id: str,
+ ) -> Dict[str, Dict[str, Any]]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregations for that event. Events
+ without bundled aggregations are omitted from the result.
+ """
+ # If bundled aggregations are disabled, nothing to do.
+ if not self._msc1849_enabled:
+ return {}
+
+ # TODO Parallelize.
+ results = {}
+ for event in events:
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result is not None:
+ results[event.event_id] = event_result
+
+ return results
+
class RelationsStore(RelationsWorkerStore):
pass
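# Illustrative sketch (assumptions: `store` and `events` are stand-ins for the
# real data store and EventBase objects; shapes simplified). The map built by
# get_bundled_aggregations is keyed by event ID, and each entry maps a relation
# type constant (e.g. RelationTypes.THREAD) to its type-specific summary, such
# as {"latest_event": <event>, "count": 2, "current_user_participated": True}:
from typing import Any, Dict, Iterable

async def print_aggregations(store: Any, events: Iterable[Any], user_id: str) -> None:
    aggregations: Dict[str, Dict[str, Any]] = await store.get_bundled_aggregations(
        events, user_id
    )
    for event in events:
        # Events with no relations simply have no entry in the map.
        print(event.event_id, aggregations.get(event.event_id))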
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index c0e837854..95167116c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -551,24 +551,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
- %s
- ORDER BY %s %s
+ {where}
+ ORDER BY {order_by} {direction}, state.room_id {direction}
LIMIT ?
OFFSET ?
- """ % (
- where_statement,
- order_by_column,
- "ASC" if order_by_asc else "DESC",
+ """.format(
+ where=where_statement,
+ order_by=order_by_column,
+ direction="ASC" if order_by_asc else "DESC",
)
# Use a nested SELECT statement as SQL can't count(*) with an OFFSET
count_sql = """
SELECT count(*) FROM (
SELECT room_id FROM room_stats_state state
- %s
+ {where}
) AS get_room_ids
- """ % (
- where_statement,
+ """.format(
+ where=where_statement,
)
def _get_rooms_paginate_txn(
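# Illustrative aside (toy values, not the real query). Named .format() fields
# let one value appear several times - {direction} is used twice above - which
# positional %-formatting could only do by repeating the argument:
order_by = "curr.joined_members"
direction = "DESC"
sql = """
    SELECT room_id FROM room_stats_state state
    ORDER BY {order_by} {direction}, state.room_id {direction}
""".format(order_by=order_by, direction=direction)
print(sql)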
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cda80d651..4489732fd 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1177,18 +1177,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
await self.db_pool.runInteraction("forget_membership", f)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
"""The cached data used by the `_get_joined_hosts_cache`."""
# Dict of host to the set of their users in the room at the state group.
- hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
+ hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
# The state group `hosts_to_joined_users` is derived from. Will be an object
# if the instance is newly created or if the state is not based on a state
# group. (An object is used as a sentinel value to ensure that it never is
# equal to anything else).
- state_group = attr.ib(type=Union[object, int], factory=object)
+ state_group: Union[object, int] = attr.Factory(object)
def __len__(self):
return sum(len(v) for v in self.hosts_to_joined_users.values())
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
index 5a9712043..e8c776b97 100644
--- a/synapse/storage/databases/main/session.py
+++ b/synapse/storage/databases/main/session.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6c299cafa..4b78b4d09 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -560,3 +560,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction(
"get_destinations_paginate_txn", get_destinations_paginate_txn
)
+
+ async def is_destination_known(self, destination: str) -> bool:
+ """Check if a destination is known to the server."""
+ result = await self.db_pool.simple_select_one_onecol(
+ table="destinations",
+ keyvalues={"destination": destination},
+ retcol="1",
+ allow_none=True,
+ desc="is_destination_known",
+ )
+ return bool(result)
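# Illustrative sketch (assumption: plain sqlite3 stands in for the database
# pool). The retcol="1" / allow_none=True pair is an existence check - in
# effect SELECT 1 FROM destinations WHERE destination = ? - with the
# row-or-None result collapsed to a bool:
import sqlite3

def is_destination_known_sketch(conn: sqlite3.Connection, destination: str) -> bool:
    row = conn.execute(
        "SELECT 1 FROM destinations WHERE destination = ?", (destination,)
    ).fetchone()
    return row is not None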
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index a1a1a6a14..2d339b600 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,19 +23,19 @@ from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class UIAuthSessionData:
- session_id = attr.ib(type=str)
+ session_id: str
# The dictionary from the client root level, not the 'auth' key.
- clientdict = attr.ib(type=JsonDict)
+ clientdict: JsonDict
# The URI and method the session was initiated with. These are checked at
# each stage of the authentication to ensure that the asked for operation
# has not changed.
- uri = attr.ib(type=str)
- method = attr.ib(type=str)
+ uri: str
+ method: str
# A string description of the operation that the current authentication is
# authorising.
- description = attr.ib(type=str)
+ description: str
class UIAuthWorkerStore(SQLBaseStore):
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0f9b8575d..f7c778bdf 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -105,8 +105,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
GROUP BY room_id
"""
txn.execute(sql)
- rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ rooms = list(txn.fetchall())
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
+ )
del rooms
sql = (
@@ -117,9 +119,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute(sql)
txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ users = list(txn.fetchall())
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(
+ txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
+ )
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index eb1118d2c..5de70f31d 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -327,14 +327,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=(
+ "state_group",
+ "room_id",
+ "type",
+ "state_key",
+ "event_id",
+ ),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_state.items()
],
)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index c4c8c0021..7614d76ac 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -460,14 +460,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_ids.items()
],
)
@@ -475,14 +470,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (state_group, room_id, key[0], key[1], state_id)
for key, state_id in current_state_ids.items()
],
)
@@ -589,14 +579,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
+ (sg, room_id, key[0], key[1], state_id)
for key, state_id in curr_state.items()
],
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 540adb878..71584f3f7 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -21,7 +21,7 @@ from signedjson.types import VerifyKey
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResult:
- verify_key = attr.ib(type=VerifyKey) # the key itself
- valid_until_ts = attr.ib(type=int) # how long we can use this key for
+ verify_key: VerifyKey # the key itself
+ valid_until_ts: int # how long we can use this key for
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e45adfcb5..1823e1872 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -696,7 +696,7 @@ def _get_or_create_schema_state(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
@@ -705,5 +705,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib(type=str)
- absolute_path = attr.ib(type=str)
+ file_name: str
+ absolute_path: str
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 10a46b5e8..b1536c1ca 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PaginationChunk:
"""Returned by relation pagination APIs.
@@ -35,9 +35,9 @@ class PaginationChunk:
None then there are no previous results.
"""
- chunk = attr.ib(type=List[JsonDict])
- next_batch = attr.ib(type=Optional[Any], default=None)
- prev_batch = attr.ib(type=Optional[Any], default=None)
+ chunk: List[JsonDict]
+ next_batch: Optional[Any] = None
+ prev_batch: Optional[Any] = None
def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
@@ -51,7 +51,7 @@ class PaginationChunk:
return d
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class RelationPaginationToken:
"""Pagination token for relation pagination API.
@@ -64,8 +64,8 @@ class RelationPaginationToken:
stream: The stream ordering of the boundary event.
"""
- topological = attr.ib(type=int)
- stream = attr.ib(type=int)
+ topological: int
+ stream: int
@staticmethod
def from_string(string: str) -> "RelationPaginationToken":
@@ -82,7 +82,7 @@ class RelationPaginationToken:
return attr.astuple(self)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class AggregationPaginationToken:
"""Pagination token for relation aggregation pagination API.
@@ -94,8 +94,8 @@ class AggregationPaginationToken:
stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib(type=int)
- stream = attr.ib(type=int)
+ count: int
+ stream: int
@staticmethod
def from_string(string: str) -> "AggregationPaginationToken":
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b5ba1560d..df8b2f108 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -45,7 +45,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class StateFilter:
"""A filter used when querying for state.
@@ -58,8 +58,8 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
- include_others = attr.ib(default=False, type=bool)
+ types: "frozendict[str, Optional[FrozenSet[str]]]"
+ include_others: bool = False
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b8112e1c0..3c13859fa 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -762,13 +762,13 @@ class _AsyncCtxManagerWrapper(Generic[T]):
return self.inner.__exit__(exc_type, exc, tb)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
- id_gen = attr.ib(type=MultiWriterIdGenerator)
- multiple_ids = attr.ib(type=Optional[int], default=None)
- stream_ids = attr.ib(type=List[int], factory=list)
+ id_gen: MultiWriterIdGenerator
+ multiple_ids: Optional[int] = None
+ stream_ids: List[int] = attr.Factory(list)
async def __aenter__(self) -> Union[int, List[int]]:
# It's safe to run this in autocommit mode as fetching values from a
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index c08d591f2..b52723e2b 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
"""A configuration object which stores pagination parameters."""
- from_token = attr.ib(type=Optional[StreamToken])
- to_token = attr.ib(type=Optional[StreamToken])
- direction = attr.ib(type=str)
- limit = attr.ib(type=Optional[int])
+ from_token: Optional[StreamToken]
+ to_token: Optional[StreamToken]
+ direction: str
+ limit: Optional[int]
@classmethod
async def from_request(
diff --git a/synapse/types.py b/synapse/types.py
index 42aeaf627..f89fb216a 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -20,7 +20,9 @@ from typing import (
Any,
ClassVar,
Dict,
+ List,
Mapping,
+ Match,
MutableMapping,
Optional,
Tuple,
@@ -79,7 +81,7 @@ class ISynapseReactor(
"""The interfaces necessary for Synapse to function."""
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class Requester:
"""
Represents the user making a request
@@ -97,13 +99,13 @@ class Requester:
"puppeting" the user.
"""
- user = attr.ib(type="UserID")
- access_token_id = attr.ib(type=Optional[int])
- is_guest = attr.ib(type=bool)
- shadow_banned = attr.ib(type=bool)
- device_id = attr.ib(type=Optional[str])
- app_service = attr.ib(type=Optional["ApplicationService"])
- authenticated_entity = attr.ib(type=str)
+ user: "UserID"
+ access_token_id: Optional[int]
+ is_guest: bool
+ shadow_banned: bool
+ device_id: Optional[str]
+ app_service: Optional["ApplicationService"]
+ authenticated_entity: str
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
@@ -210,7 +212,7 @@ def get_localpart_from_id(string: str) -> str:
DS = TypeVar("DS", bound="DomainSpecificString")
-@attr.s(slots=True, frozen=True, repr=False)
+@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True)
class DomainSpecificString(metaclass=abc.ABCMeta):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -223,8 +225,8 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
- localpart = attr.ib(type=str)
- domain = attr.ib(type=str)
+ localpart: str
+ domain: str
# Because this is a frozen class, it is deeply immutable.
def __copy__(self):
@@ -380,7 +382,7 @@ def map_username_to_mxid_localpart(
onto different mxids
Returns:
- unicode: string suitable for a mxid localpart
+ string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
username = username.encode("utf-8")
@@ -388,29 +390,23 @@ def map_username_to_mxid_localpart(
# first we sort out upper-case characters
if case_sensitive:
- def f1(m):
+ def f1(m: Match[bytes]) -> bytes:
return b"_" + m.group().lower()
username = UPPER_CASE_PATTERN.sub(f1, username)
else:
username = username.lower()
- # then we sort out non-ascii characters
- def f2(m):
- g = m.group()[0]
- if isinstance(g, str):
- # on python 2, we need to do a ord(). On python 3, the
- # byte itself will do.
- g = ord(g)
- return b"=%02x" % (g,)
+ # then we sort out non-ascii characters by converting to the hex equivalent.
+ def f2(m: Match[bytes]) -> bytes:
+ return b"=%02x" % (m.group()[0],)
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
# we also do the =-escaping to mxids starting with an underscore.
username = re.sub(b"^_", b"=5f", username)
- # we should now only have ascii bytes left, so can decode back to a
- # unicode.
+ # we should now only have ascii bytes left, so can decode back to a string.
return username.decode("ascii")
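# Worked examples (illustrative, derived from the substitution rules above and
# assuming the default is case-insensitive):
#   map_username_to_mxid_localpart("bob")                      -> "bob"
#   map_username_to_mxid_localpart("Bob")                      -> "bob"     (case folded)
#   map_username_to_mxid_localpart("Bob", case_sensitive=True) -> "_bob"
#   map_username_to_mxid_localpart("bob!")                     -> "bob=21"  (0x21 == "!")
#   map_username_to_mxid_localpart("_bob")                     -> "=5fbob"  (leading _ escaped)
# The f2 hex-escape for a single matched byte, stated on its own:
def hex_escape_byte(match_bytes: bytes) -> bytes:
    # On Python 3, indexing bytes yields an int, which %02x formats directly.
    return b"=%02x" % (match_bytes[0],)

assert hex_escape_byte(b"!") == b"=21"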
@@ -466,14 +462,12 @@ class RoomStreamToken:
attributes, must be hashable.
"""
- topological = attr.ib(
- type=Optional[int],
+ topological: Optional[int] = attr.ib(
validator=attr.validators.optional(attr.validators.instance_of(int)),
)
- stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ stream: int = attr.ib(validator=attr.validators.instance_of(int))
- instance_map = attr.ib(
- type="frozendict[str, int]",
+ instance_map: "frozendict[str, int]" = attr.ib(
factory=frozendict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
@@ -482,7 +476,7 @@ class RoomStreamToken:
),
)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
"""Validates that both `topological` and `instance_map` aren't set."""
if self.instance_map and self.topological:
@@ -598,7 +592,7 @@ class RoomStreamToken:
return "s%d" % (self.stream,)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class StreamToken:
"""A collection of positions within multiple streams.
@@ -606,20 +600,20 @@ class StreamToken:
must be hashable.
"""
- room_key = attr.ib(
- type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+ room_key: RoomStreamToken = attr.ib(
+ validator=attr.validators.instance_of(RoomStreamToken)
)
- presence_key = attr.ib(type=int)
- typing_key = attr.ib(type=int)
- receipt_key = attr.ib(type=int)
- account_data_key = attr.ib(type=int)
- push_rules_key = attr.ib(type=int)
- to_device_key = attr.ib(type=int)
- device_list_key = attr.ib(type=int)
- groups_key = attr.ib(type=int)
+ presence_key: int
+ typing_key: int
+ receipt_key: int
+ account_data_key: int
+ push_rules_key: int
+ to_device_key: int
+ device_list_key: int
+ groups_key: int
_SEPARATOR = "_"
- START: "StreamToken"
+ START: ClassVar["StreamToken"]
@classmethod
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
@@ -679,7 +673,7 @@ class StreamToken:
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedEventPosition:
"""Position of a newly persisted event with instance that persisted it.
@@ -687,8 +681,8 @@ class PersistedEventPosition:
RoomStreamToken.
"""
- instance_name = attr.ib(type=str)
- stream = attr.ib(type=int)
+ instance_name: str
+ stream: int
def persisted_after(self, token: RoomStreamToken) -> bool:
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
@@ -738,15 +732,15 @@ class ThirdPartyInstanceID:
__str__ = to_string
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReadReceipt:
"""Information about a read-receipt"""
- room_id = attr.ib()
- receipt_type = attr.ib()
- user_id = attr.ib()
- event_ids = attr.ib()
- data = attr.ib()
+ room_id: str
+ receipt_type: str
+ user_id: str
+ event_ids: List[str]
+ data: JsonDict
def get_verify_key_from_cross_signing_key(key_info):
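# Illustrative aside (hypothetical class, mirroring the shape RoomStreamToken
# keeps above). Under auto_attribs, a field that still needs attrs machinery
# (a validator or factory) keeps attr.ib() as its right-hand side, with the
# annotation supplying the type:
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleToken:
    stream: int = attr.ib(validator=attr.validators.instance_of(int))
    topological: int = 0  # a plain default needs no attr.ib()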
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 150a04b53..3f7299aff 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -309,12 +309,12 @@ def gather_results( # type: ignore[misc]
return deferred.addCallback(tuple)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.
- count = attr.ib(type=int)
+ count: int
# Deferreds for the things blocked from executing.
- deferreds = attr.ib(type=collections.OrderedDict)
+ deferreds: collections.OrderedDict
class Linearizer:
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 485ddb189..d267703df 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -33,7 +33,7 @@ DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
@@ -41,14 +41,13 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
- known_absent: Keys that were looked up in the dict and were not
- there.
+ known_absent: Keys that were looked up in the dict and were not there.
value: The full or partial dict value
"""
- full = attr.ib(type=bool)
- known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
- value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
+ full: bool
+ known_absent: Set[Any] # should be: Set[DKT]
+ value: Dict[Any, Any] # should be: Dict[DKT, DV]
def __len__(self) -> int:
return len(self.value)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index a2dfa1ed0..4b53b6d40 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -274,6 +274,39 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEquals(failure.value.code, 400)
self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
+ self.store.insert_client_ip.assert_called_once()
+
+ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
+ self.auth._track_puppeted_user_ips = True
+ self.store.get_user_by_access_token = simple_async_mock(
+ TokenLookupResult(
+ user_id="@baldrick:matrix.org",
+ device_id="device",
+ token_owner="@admin:matrix.org",
+ )
+ )
+ self.store.insert_client_ip = simple_async_mock(None)
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args[b"access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_success(self.auth.get_user_by_req(request))
+ self.assertEquals(self.store.insert_client_ip.call_count, 2)
+
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee34..734ed84d7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable
from unittest import mock
+from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from tests import unittest
+from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_user_id = "@test:other"
local_user_id = "@test:test"
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
)
+
+ @parameterized.expand(
+ [
+ # The remote homeserver's response indicates that this user has 0/1/2 devices.
+ ([],),
+ (["device_1"],),
+ (["device_1", "device_2"],),
+ ]
+ )
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ """Test that requests for all of a remote user's devices are cached.
+
+ We do this by asserting that only one call over federation was made, and that
+ the two queries to the local homeserver produce the same response.
+ """
+ local_user_id = "@test:test"
+ remote_user_id = "@test:other"
+ request_body = {"device_keys": {remote_user_id: []}}
+
+ response_devices = [
+ {
+ "device_id": device_id,
+ "keys": {
+ "algorithms": ["dummy"],
+ "device_id": device_id,
+ "keys": {f"dummy:{device_id}": "dummy"},
+ "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+ "unsigned": {},
+ "user_id": "@test:other",
+ },
+ }
+ for device_id in device_ids
+ ]
+
+ response_body = {
+ "devices": response_devices,
+ "user_id": remote_user_id,
+ "stream_id": 12345, # an integer, according to the spec
+ }
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
+ mock_get_rooms = mock.patch.object(
+ self.store,
+ "get_rooms_for_user",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(["some_room_id"]),
+ )
+ mock_request = mock.patch.object(
+ self.hs.get_federation_client(),
+ "query_user_devices",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(response_body),
+ )
+
+ with mock_get_rooms, mock_request as mocked_federation_request:
+ # Make the first query and sanity check it succeeds.
+ response_1 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_1["failures"], {})
+
+ # We should have made a federation request to do so.
+ mocked_federation_request.assert_called_once()
+
+ # Reset the mock so we can prove we don't make a second federation request.
+ mocked_federation_request.reset_mock()
+
+ # Repeat the query.
+ response_2 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_2["failures"], {})
+
+ # We should not have made a second federation request.
+ mocked_federation_request.assert_not_called()
+
+ # The two requests to the local homeserver should be identical.
+ self.assertEqual(response_1, response_2)
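The new test uses `parameterized.expand`, which turns each argument tuple into its own test method. A self-contained sketch of the pattern (class and values invented for illustration):

```python
from unittest import TestCase

from parameterized import parameterized


class SquareTests(TestCase):
    @parameterized.expand(
        [
            (0, 0),
            (2, 4),
            (3, 9),
        ]
    )
    def test_square(self, n: int, expected: int) -> None:
        # Expands into three independent test methods, one per tuple.
        self.assertEqual(n * n, expected)
```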
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 08e9730d4..2add72b28 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login
+from synapse.rest.client import devices, login, logout
from synapse.types import JsonDict
from tests import unittest
@@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
+ logout.register_servlets,
]
def setUp(self):
@@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
+ def test_on_logged_out(self):
+ """Tests that the on_logged_out callback is called when the user logs out."""
+ self.register_user("rin", "password")
+ tok = self.login("rin", "password")
+
+ self.called = False
+
+ async def on_logged_out(user_id, device_id, access_token):
+ self.called = True
+
+ on_logged_out = Mock(side_effect=on_logged_out)
+ self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+ on_logged_out
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/logout",
+ {},
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200)
+ on_logged_out.assert_called_once()
+ self.assertTrue(self.called)
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
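For context, `on_logged_out_callbacks` is normally populated by a module rather than appended to directly; a minimal module sketch, assuming the `register_password_auth_provider_callbacks` method of Synapse's module API:

```python
# A hypothetical module registering the on_logged_out callback tested above.
from typing import Optional


class ExampleAuthModule:
    def __init__(self, config: dict, api) -> None:
        # `api` is expected to be a synapse.module_api.ModuleApi instance.
        api.register_password_auth_provider_callbacks(
            on_logged_out=self.on_logged_out,
        )

    async def on_logged_out(
        self, user_id: str, device_id: Optional[str], access_token: str
    ) -> None:
        # Invoked after a successful /logout for the given device/token.
        print(f"user {user_id} logged out device {device_id}")
```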
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e5a6a6c74..51b22d299 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -28,6 +28,7 @@ from synapse.api.constants import (
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.transport.client import TransportLayerClient
from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -134,10 +135,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, self.room, self.token)
def _add_child(
- self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+ self,
+ space_id: str,
+ room_id: str,
+ token: str,
+ order: Optional[str] = None,
+ via: Optional[List[str]] = None,
) -> None:
"""Add a child room to a space."""
- content: JsonDict = {"via": [self.hs.hostname]}
+ if via is None:
+ via = [self.hs.hostname]
+
+ content: JsonDict = {"via": via}
if order is not None:
content["order"] = order
self.helper.send_state(
@@ -253,6 +262,38 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_large_space(self):
+ """Test a space with a large number of rooms."""
+ rooms = [self.room]
+ # Make at least 51 rooms that are part of the space.
+ for _ in range(55):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token)
+ rooms.append(room)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The space summary result is capped at 50 entries: the space itself plus
+ # the first 49 rooms, along with the links from space -> room for the
+ # space's first 50 children.
+ expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]]
+ self._assert_rooms(result, expected)
+
+ # The result should have the space and the rooms in it, along with the links
+ # from space -> room.
+ expected = [(self.space, rooms)] + [(room, []) for room in rooms]
+
+ # Make two requests to fully paginate the results.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ result2 = self.get_success(
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token=result["next_batch"]
+ )
+ )
+ # Combine the results.
+ result["rooms"] += result2["rooms"]
+ self._assert_hierarchy(result, expected)
+
def test_visibility(self):
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
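The two-request dance in `test_large_space` above mirrors how a caller fully paginates `get_room_hierarchy`: keep following `next_batch` until it disappears. A rough sketch under that assumption, with the per-request call passed in so the snippet stays self-contained:

```python
from typing import Any, Awaitable, Callable, Dict, List

Page = Dict[str, Any]


async def fetch_all_hierarchy_rooms(
    space_id: str,
    fetch_page: Callable[..., Awaitable[Page]],
) -> List[Dict[str, Any]]:
    """Follow next_batch tokens until the hierarchy is fully paginated.

    fetch_page is a stand-in for one GET .../hierarchy[?from=...] request.
    """
    rooms: List[Dict[str, Any]] = []
    from_token = None
    while True:
        page = await fetch_page(space_id, from_token=from_token)
        rooms.extend(page["rooms"])
        from_token = page.get("next_batch")
        if from_token is None:
            return rooms
```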
@@ -1004,6 +1045,85 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_fed_caching(self):
+ """
+ Federation `/hierarchy` responses should be cached.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_subspace = "#space:" + fed_hostname
+ fed_room = "#room:" + fed_hostname
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, fed_subspace, self.token, via=[fed_hostname])
+
+ federation_requests = 0
+
+ async def get_room_hierarchy(
+ _self: TransportLayerClient,
+ destination: str,
+ room_id: str,
+ suggested_only: bool,
+ ) -> JsonDict:
+ nonlocal federation_requests
+ federation_requests += 1
+
+ return {
+ "room": {
+ "room_id": fed_subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ "children_state": [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_subspace,
+ "state_key": fed_room,
+ "content": {"via": [fed_hostname]},
+ },
+ ],
+ },
+ "children": [
+ {
+ "room_id": fed_room,
+ "world_readable": True,
+ },
+ ],
+ "inaccessible_children": [],
+ }
+
+ expected = [
+ (self.space, [self.room, fed_subspace]),
+ (self.room, ()),
+ (fed_subspace, [fed_room]),
+ (fed_room, ()),
+ ]
+
+ with mock.patch(
+ "synapse.federation.transport.client.TransportLayerClient.get_room_hierarchy",
+ new=get_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # The previous federation response should be reused.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # Expire the response cache
+ self.reactor.advance(5 * 60 + 1)
+
+ # A new federation request should be made.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 2)
+ self._assert_hierarchy(result, expected)
+
class RoomSummaryTestCase(unittest.HomeserverTestCase):
servlets = [
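The assertions in `test_fed_caching` above describe a time-bounded response cache: repeated requests reuse the federation result until the reactor is advanced past five minutes. A toy cache with the same observable behaviour (the real mechanism is Synapse's response cache; the TTL here is inferred from `reactor.advance(5 * 60 + 1)`):

```python
import time
from typing import Any, Callable, Dict, Tuple


class ToyTTLCache:
    """A minimal time-based cache mirroring the behaviour asserted above."""

    def __init__(self, ttl_seconds: float = 5 * 60) -> None:
        self._ttl = ttl_seconds
        self._entries: Dict[Any, Tuple[float, Any]] = {}

    def get(self, key: Any, compute: Callable[[], Any]) -> Any:
        now = time.monotonic()
        entry = self._entries.get(key)
        if entry is not None and now - entry[0] <= self._ttl:
            return entry[1]  # fresh hit: no second "federation request"
        value = compute()  # miss or expired entry: recompute and restamp
        self._entries[key] = (now, value)
        return value
```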
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 638186f17..07a760e91 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -27,6 +26,7 @@ from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
+from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -186,6 +186,97 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+ def test_ban_wins_race_with_join(self):
+ """Rooms shouldn't appear under "joined" if a join loses a race to a ban.
+
+ A complicated edge case. Imagine the following scenario:
+
+ * you attempt to join a room
+ * racing with that, a ban comes in over federation and ends up with
+ an earlier stream_ordering than the join.
+ * you get a sync response with a sync token which is _after_ the ban, but before
+ the join
+ * now your join lands; it is a valid event because its `prev_event`s predate the
+ ban, but will not make it into current_state_events (because bans win over
+ joins in state res, essentially).
+ * When we do an incremental sync from that token, the only event in the
+ timeline is your join ... and yet you aren't joined.
+
+ The ban coming in over federation isn't crucial for this behaviour; the key
+ requirements are:
+ 1. the homeserver generates a join event with prev_events that precede the ban
+ (so that it passes the "are you banned" test)
+ 2. the join event has a stream_ordering after that of the ban.
+
+ We use monkeypatching to artificially trigger condition (1).
+ """
+ # A local user Alice creates a room.
+ owner = self.register_user("alice", "password")
+ owner_tok = self.login(owner, "password")
+ room_id = self.helper.create_room_as(owner, is_public=True, tok=owner_tok)
+
+ # Do a sync as Alice to get the latest event in the room.
+ alice_sync_result: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(owner), generate_sync_config(owner)
+ )
+ )
+ self.assertEqual(len(alice_sync_result.joined), 1)
+ self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
+ last_room_creation_event_id = (
+ alice_sync_result.joined[0].timeline.events[-1].event_id
+ )
+
+ # Eve, a ne'er-do-well, registers.
+ eve = self.register_user("eve", "password")
+ eve_token = self.login(eve, "password")
+
+ # Alice preemptively bans Eve.
+ self.helper.ban(room_id, owner, eve, tok=owner_tok)
+
+ # Eve syncs.
+ eve_requester = create_requester(eve)
+ eve_sync_config = generate_sync_config(eve)
+ eve_sync_after_ban: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
+ )
+
+ # Sanity check this sync result. We shouldn't be joined to the room.
+ self.assertEqual(eve_sync_after_ban.joined, [])
+
+ # Eve tries to join the room. We monkey patch the internal logic which selects
+ # the prev_events used when creating the join event, such that the ban does not
+ # precede the join.
+ mocked_get_prev_events = patch.object(
+ self.hs.get_datastore(),
+ "get_prev_events_for_room",
+ new_callable=MagicMock,
+ return_value=make_awaitable([last_room_creation_event_id]),
+ )
+ with mocked_get_prev_events:
+ self.helper.join(room_id, eve, tok=eve_token)
+
+ # Eve makes a second, incremental sync.
+ eve_incremental_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=eve_sync_after_ban.next_batch,
+ )
+ )
+ # Eve should not see herself as joined to the room.
+ self.assertEqual(eve_incremental_sync_after_join.joined, [])
+
+ # If we did a third, initial sync, we should _still_ see that Eve is not joined to the room.
+ eve_initial_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=None,
+ )
+ )
+ self.assertEqual(eve_initial_sync_after_join.joined, [])
+
_request_key = 0
diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py
new file mode 100644
index 000000000..ee5cf299f
--- /dev/null
+++ b/tests/http/test_webclient.py
@@ -0,0 +1,108 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig
+from synapse.http.site import SynapseSite
+
+from tests.server import make_request
+from tests.unittest import HomeserverTestCase, create_resource_tree, override_config
+
+
+class WebClientTests(HomeserverTestCase):
+ @override_config(
+ {
+ "web_client_location": "https://example.org",
+ }
+ )
+ def test_webclient_resolves_with_client_resource(self):
+ """
+ Tests that both client and webclient resources can be accessed simultaneously.
+
+ This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763.
+ """
+ for resource_name_order_list in [
+ ["webclient", "client"],
+ ["client", "webclient"],
+ ]:
+ # Create a dictionary from path regex -> resource
+ resource_dict: Dict[str, Resource] = {}
+
+ for resource_name in resource_name_order_list:
+ resource_dict.update(
+ SynapseHomeServer._configure_named_resource(self.hs, resource_name)
+ )
+
+ # Create a root resource which ties the above resources together into one
+ root_resource = Resource()
+ create_resource_tree(resource_dict, root_resource)
+
+ # Create a site configured with this resource to make HTTP requests against
+ listener_config = ListenerConfig(
+ port=8008,
+ bind_addresses=["127.0.0.1"],
+ type="http",
+ http_options=HttpListenerConfig(
+ resources=[HttpResourceConfig(names=resource_name_order_list)]
+ ),
+ )
+ test_site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag=self.hs.config.server.server_name,
+ config=listener_config,
+ resource=root_resource,
+ server_version_string="1",
+ max_request_body_size=1234,
+ reactor=self.reactor,
+ )
+
+ # Attempt to make requests to endpoints on both the webclient and client resources
+ # on test_site.
+ self._request_client_and_webclient_resources(test_site)
+
+ def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None:
+ """Make a request to an endpoint on both the webclient and client-server resources
+ of the given SynapseSite.
+
+ Args:
+ test_site: The SynapseSite object to make requests against.
+ """
+
+ # Ensure that the *webclient* resource is behaving as expected (we get redirected to
+ # the configured web_client_location)
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client",
+ )
+ # Check that we are being redirected to the webclient location URI.
+ self.assertEqual(channel.code, HTTPStatus.FOUND)
+ self.assertEqual(
+ channel.headers.getRawHeaders("Location"), ["https://example.org"]
+ )
+
+ # Ensure that a request to the *client* resource works.
+ channel = make_request(
+ self.reactor,
+ site=test_site,
+ method="GET",
+ path="/_matrix/client/v3/login",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertIn("flows", channel.json_body)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 742f19425..b70350b6f 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
retry_interval,
last_successful_stream_ordering,
) in dest:
- self.get_success(
- self.store.set_destination_retry_timings(
- destination, failure_ts, retry_last_ts, retry_interval
- )
- )
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(
- destination, last_successful_stream_ordering
- )
+ self._create_destination(
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ last_successful_stream_ordering,
)
# order by default (destination)
@@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- def test_get_single_destination(self) -> None:
- """
- Get one specific destinations.
- """
- self._create_destinations(5)
+ def test_get_single_destination_with_retry_timings(self) -> None:
+ """Get one specific destination which has retry timings."""
+ self._create_destinations(1)
channel = self.make_request(
"GET",
@@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase):
# convert channel.json_body into a List
self._check_fields([channel.json_body])
+ def test_get_single_destination_no_retry_timings(self) -> None:
+ """Get one specific destination which has no retry timings."""
+ self._create_destination("sub0.example.com")
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/sub0.example.com",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual("sub0.example.com", channel.json_body["destination"])
+ self.assertEqual(0, channel.json_body["retry_last_ts"])
+ self.assertEqual(0, channel.json_body["retry_interval"])
+ self.assertIsNone(channel.json_body["failure_ts"])
+ self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
+
+ def _create_destination(
+ self,
+ destination: str,
+ failure_ts: Optional[int] = None,
+ retry_last_ts: int = 0,
+ retry_interval: int = 0,
+ last_successful_stream_ordering: Optional[int] = None,
+ ) -> None:
+ """Create one specific destination
+
+ Args:
+ destination: the destination we have successfully sent to
+ failure_ts: when the server started failing (ms since epoch)
+ retry_last_ts: time of last retry attempt in unix epoch ms
+ retry_interval: how long until next retry in ms
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ self.get_success(
+ self.store.set_destination_retry_timings(
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ )
+ if last_successful_stream_ordering is not None:
+ self.get_success(
+ self.store.set_destination_last_successful_stream_ordering(
+ destination, last_successful_stream_ordering
+ )
+ )
+
def _create_destinations(self, number_destinations: int) -> None:
"""Create a number of destinations
@@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"""
for i in range(0, number_destinations):
dest = f"sub{i}.example.com"
- self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
- self.get_success(
- self.store.set_destination_last_successful_stream_ordering(dest, 100)
- )
+ self._create_destination(dest, 50, 50, 50, 100)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected destination attributes are present in content
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 81f3ac7f0..8513b1d2d 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -223,20 +223,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# Create all possible single character tokens
tokens = []
for c in string.ascii_letters + string.digits + "._~-":
- tokens.append(
- {
- "token": c,
- "uses_allowed": None,
- "pending": 0,
- "completed": 0,
- "expiry_time": None,
- }
- )
+ tokens.append((c, None, 0, 0, None))
self.get_success(
self.store.db_pool.simple_insert_many(
"registration_tokens",
- tokens,
- "create_all_registration_tokens",
+ keys=("token", "uses_allowed", "pending", "completed", "expiry_time"),
+ values=tokens,
+ desc="create_all_registration_tokens",
)
)
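Both this hunk and the one in `tests/storage/test_event_federation.py` below reflect a changed `simple_insert_many` signature: a shared `keys` tuple plus positional row tuples, instead of one dict per row. The new call shape, with illustrative rows (written as it would appear inside a `HomeserverTestCase`, like the test above):

```python
# Rows are positional tuples aligned with the shared `keys` ordering.
keys = ("token", "uses_allowed", "pending", "completed", "expiry_time")
values = [
    ("alpha", None, 0, 0, None),
    ("beta", 10, 0, 0, None),
]
self.get_success(
    self.store.db_pool.simple_insert_many(
        "registration_tokens", keys=keys, values=values, desc="example_insert"
    )
)
```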
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d2c8781cd..3495a0366 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1089,6 +1089,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
room_ids.append(room_id)
+ room_ids.sort()
+
# Request the list of rooms
url = "/_synapse/admin/v1/rooms"
channel = self.make_request(
@@ -1360,6 +1362,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ # Also create lists sorted by room ID, for properties that are equal across rooms (and thus fall back to sorting by room_id)
+ sorted_by_room_id_asc = [room_id_1, room_id_2, room_id_3]
+ sorted_by_room_id_asc.sort()
+ sorted_by_room_id_desc = sorted_by_room_id_asc.copy()
+ sorted_by_room_id_desc.reverse()
+
# Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
self.helper.send_state(
room_id_1,
@@ -1405,41 +1413,42 @@ class RoomTestCase(unittest.HomeserverTestCase):
_order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
_order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+ # Note: joined_member counts are sorted in descending order when dir=f
_order_test("joined_members", [room_id_3, room_id_2, room_id_1])
_order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: joined_local_member counts are sorted in descending order when dir=f
_order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
_order_test(
"joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
)
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+ # Note: versions are sorted in descending order when dir=f
+ _order_test("version", sorted_by_room_id_asc, reverse=True)
+ _order_test("version", sorted_by_room_id_desc)
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("creator", sorted_by_room_id_asc)
+ _order_test("creator", sorted_by_room_id_desc, reverse=True)
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("encryption", sorted_by_room_id_asc)
+ _order_test("encryption", sorted_by_room_id_desc, reverse=True)
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("federatable", sorted_by_room_id_asc)
+ _order_test("federatable", sorted_by_room_id_desc, reverse=True)
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+ _order_test("public", sorted_by_room_id_asc)
+ _order_test("public", sorted_by_room_id_desc, reverse=True)
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("join_rules", sorted_by_room_id_asc)
+ _order_test("join_rules", sorted_by_room_id_desc, reverse=True)
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+ _order_test("guest_access", sorted_by_room_id_asc)
+ _order_test("guest_access", sorted_by_room_id_desc, reverse=True)
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
+ _order_test("history_visibility", sorted_by_room_id_asc)
+ _order_test("history_visibility", sorted_by_room_id_desc, reverse=True)
+ # Note: state_event counts are sorted in descending order when dir=f
_order_test("state_events", [room_id_3, room_id_2, room_id_1])
_order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e0b9fe8e9..971140573 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1181,6 +1181,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.other_user, device_id=None, valid_until_ms=None
)
)
+
self.url_prefix = "/_synapse/admin/v2/users/%s"
self.url_other_user = self.url_prefix % self.other_user
@@ -1188,7 +1189,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
channel = self.make_request(
"GET",
@@ -1216,7 +1217,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_synapse/admin/v2/users/@unknown_person:test",
+ self.url_prefix % "@unknown_person:test",
access_token=self.admin_user_tok,
)
@@ -1337,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new admin user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user (server admin)
body = {
@@ -1386,7 +1387,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
Check that a new regular user is created successfully.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -1478,7 +1479,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1515,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Register new user with admin API
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -1545,7 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Check that a new regular user is created successfully and
gets an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -1588,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Check that a new regular user is created successfully and
does not get an email pusher.
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
body = {
@@ -2085,10 +2086,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
# the user is deactivated, the threepid will be deleted
# Get user
@@ -2101,11 +2105,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self):
"""
@@ -2177,9 +2183,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"localdb_enabled": False}})
def test_reactivate_user_localdb_disabled(self):
"""
@@ -2209,9 +2217,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
@override_config({"password_config": {"enabled": False}})
def test_reactivate_user_password_disabled(self):
"""
@@ -2241,9 +2251,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
@@ -2328,7 +2340,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Ensure an account can't accidentally be deactivated by using a str value
for the deactivated body parameter
"""
- url = "/_synapse/admin/v2/users/@bob:test"
+ url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
@@ -2392,18 +2404,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Deactivate the user.
channel = self.make_request(
"PUT",
- "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ self.url_prefix % urllib.parse.quote(user_id),
access_token=self.admin_user_tok,
content={"deactivated": True},
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
- self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", channel.json_body)
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
@@ -2416,13 +2430,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("admin", content)
self.assertIn("deactivated", content)
self.assertIn("shadow_banned", content)
- self.assertIn("password_hash", content)
self.assertIn("creation_ts", content)
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
self.assertIn("external_ids", content)
+ # This key was removed intentionally. Ensure it is not accidentally re-included.
+ self.assertNotIn("password_hash", content)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c026d526e..c9b220e73 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,6 +21,7 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel
@@ -93,11 +94,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body,
)
- def test_deny_membership(self):
- """Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
- self.assertEquals(400, channel.code, channel.json_body)
-
def test_deny_invalid_event(self):
"""Test that we deny relations on non-existant events"""
channel = self._send_relation(
@@ -459,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_bundled_aggregations(self):
- """Test that annotations, references, and threads get correctly bundled."""
+ """
+ Test that annotations, references, and threads get correctly bundled.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+
+ See test_edit for a similar test for edits.
+ """
# Setup by sending a variety of relations.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -487,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
- def assert_bundle(actual):
+ def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
# Ensure the fields are as expected.
self.assertCountEqual(
- actual.keys(),
+ relations_dict.keys(),
(
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
@@ -508,17 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"type": "m.reaction", "key": "b", "count": 1},
]
},
- actual[RelationTypes.ANNOTATION],
+ relations_dict[RelationTypes.ANNOTATION],
)
self.assertEquals(
{"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- actual[RelationTypes.REFERENCE],
+ relations_dict[RelationTypes.REFERENCE],
)
self.assertEquals(
2,
- actual[RelationTypes.THREAD].get("count"),
+ relations_dict[RelationTypes.THREAD].get("count"),
+ )
+ self.assertTrue(
+ relations_dict[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
@@ -535,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"type": "m.room.test",
"user_id": self.user_id,
},
- actual[RelationTypes.THREAD].get("latest_event"),
+ relations_dict[RelationTypes.THREAD].get("latest_event"),
)
- def _find_and_assert_event(events):
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- break
- else:
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
- assert_bundle(event["unsigned"].get("m.relations"))
-
# Request the event directly.
channel = self.make_request(
"GET",
@@ -556,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body)
# Request the room messages.
channel = self.make_request(
@@ -565,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- _find_and_assert_event(channel.json_body["chunk"])
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
# Request the room context.
channel = self.make_request(
@@ -574,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
+ assert_bundle(channel.json_body["event"])
# Request sync.
- # channel = self.make_request("GET", "/sync", access_token=self.user_token)
- # self.assertEquals(200, channel.code, channel.json_body)
- # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- # self.assertTrue(room_timeline["limited"])
- # _find_and_assert_event(room_timeline["events"])
-
- # Note that /relations is tested separately in test_aggregation_get_event_for_thread
- # since it needs different data configured.
+ channel = self.make_request("GET", "/sync", access_token=self.user_token)
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ self._find_event_in_chunk(room_timeline["events"])
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
@@ -779,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"]
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
-
self.assertEquals(channel.json_body["content"], new_body)
+ assert_bundle(channel.json_body)
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.REPLACE, relations_dict)
-
- m_replace_dict = relations_dict[RelationTypes.REPLACE]
- for key in ["event_id", "sender", "origin_server_ts"]:
- self.assertIn(key, m_replace_dict)
-
- self.assert_dict(
- {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
)
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync, but limit the timeline so it becomes limited (and includes
+ # bundled aggregations).
+ filter = urllib.parse.quote_plus(
+ '{"room": {"timeline": {"limit": 2}}}'.encode()
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
@@ -1104,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and return it; raises if the event is not found.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
def _send_relation(
self,
relation_type: str,
@@ -1119,7 +1155,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
- content: The content of the created event.
+ content: The content of the created event. Will be modified to configure
+ the m.relates_to key based on the other provided parameters.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
@@ -1130,17 +1167,21 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if not access_token:
access_token = self.user_token
- query = ""
- if key:
- query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
-
original_id = parent_id if parent_id else self.parent_id
+ if content is None:
+ content = {}
+ content["m.relates_to"] = {
+ "event_id": original_id,
+ "rel_type": relation_type,
+ }
+ if key is not None:
+ content["m.relates_to"]["key"] = key
+
channel = self.make_request(
"POST",
- "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
- % (self.room, original_id, relation_type, event_type, query),
- content or {},
+ f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+ content,
access_token=access_token,
)
return channel
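After this change `_send_relation` posts an ordinary event via `/send`, carrying the relation in the event content instead of in the old unstable URL. Roughly, the content it builds looks like this (IDs illustrative):

```python
# Posted to /_matrix/client/v3/rooms/{room_id}/send/{event_type}.
content = {
    "m.relates_to": {
        "rel_type": "m.annotation",      # e.g. RelationTypes.ANNOTATION
        "event_id": "$parent_event_id",  # the event being related to
        "key": "👍",                     # aggregation key; annotations only
    },
}
```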
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b58452195..fe5b536d9 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(event)
time_now = self.clock.time_msec()
- serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+ serialized = self.serializer.serialize_event(event, time_now)
return serialized
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 1af5e5cee..842438358 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -196,6 +196,16 @@ class RestHelper:
expect_code=expect_code,
)
+ def ban(self, room: str, src: str, targ: str, **kwargs: object):
+ """A convenience helper: `change_membership` with `membership` preset to "ban"."""
+ self.change_membership(
+ room=room,
+ src=src,
+ targ=targ,
+ membership=Membership.BAN,
+ **kwargs,
+ )
+
def change_membership(
self,
room: str,
diff --git a/tests/server.py b/tests/server.py
index ca2b7a5b9..a0cd14ea4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -14,6 +14,8 @@
import hashlib
import json
import logging
+import os
+import os.path
import time
import uuid
import warnings
@@ -71,6 +73,7 @@ from tests.utils import (
POSTGRES_HOST,
POSTGRES_PASSWORD,
POSTGRES_USER,
+ SQLITE_PERSIST_DB,
USE_POSTGRES_FOR_TESTS,
MockClock,
default_config,
@@ -739,9 +742,23 @@ def setup_test_homeserver(
},
}
else:
+ if SQLITE_PERSIST_DB:
+ # The current working directory is in _trial_temp, so this gets created within that directory.
+ test_db_location = os.path.abspath("test.db")
+ logger.debug("Will persist db to %s", test_db_location)
+ # Ensure each test gets a clean database.
+ try:
+ os.remove(test_db_location)
+ except FileNotFoundError:
+ pass
+ else:
+ logger.debug("Removed existing DB at %s", test_db_location)
+ else:
+ test_db_location = ":memory:"
+
database_config = {
"name": "sqlite3",
- "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+ "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index bf7808486..2bc89512f 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -531,17 +531,25 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert_many(
table="federation_inbound_events_staging",
+ keys=(
+ "origin",
+ "room_id",
+ "received_ts",
+ "event_id",
+ "event_json",
+ "internal_metadata",
+ ),
values=[
- {
- "origin": "some_origin",
- "room_id": room_id,
- "received_ts": 0,
- "event_id": f"$fake_event_id_{i + 1}",
- "event_json": json_encoder.encode(
+ (
+ "some_origin",
+ room_id,
+ 0,
+ f"$fake_event_id_{i + 1}",
+ json_encoder.encode(
{"prev_events": [prev_event_format(f"$fake_event_id_{i}")]}
),
- "internal_metadata": "{}",
- }
+ "{}",
+ )
for i in range(500)
],
desc="test_prune_inbound_federation_queue",
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3eef1c4c0..2b9804aba 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -17,7 +17,9 @@ from unittest.mock import Mock
from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
+from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext
from synapse.types import UserID, create_requester
from synapse.util import Clock
@@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
)
self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
+
+
+class StripUnsignedFromEventsTestCase(unittest.TestCase):
+ def test_strip_unauthorized_unsigned_values(self):
+ event1 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event1:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
+ }
+ filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
+ # Make sure unauthorized fields are stripped from unsigned
+ self.assertNotIn("more warez", filtered_event.unsigned)
+
+ def test_strip_event_maintains_allowed_fields(self):
+ event2 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event2:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "auth_events": [],
+ "content": {"membership": "join"},
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+
+ filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
+ self.assertIn("age", filtered_event2.unsigned)
+ self.assertEqual(14, filtered_event2.unsigned["age"])
+ self.assertNotIn("more warez", filtered_event2.unsigned)
+ # invite_room_state is allowed in events of type m.room.member
+ self.assertIn("invite_room_state", filtered_event2.unsigned)
+ self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
+
+ def test_strip_event_removes_fields_based_on_event_type(self):
+ event3 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event3:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.power_levels",
+ "origin": "test.servx",
+ "content": {},
+ "auth_events": [],
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+ filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
+ self.assertIn("age", filtered_event3.unsigned)
+ # The invite_room_state field is only permitted in events of type m.room.member
+ self.assertNotIn("invite_room_state", filtered_event3.unsigned)
+ self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 15ac2bfeb..f05a373aa 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -19,7 +19,7 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Any, Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, TypeVar
from unittest.mock import Mock
import attr
@@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-def make_awaitable(result: Any) -> Awaitable[Any]:
+def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
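The tightened generic signature lets the result type flow through to callers. Typical usage, as in the mocks elsewhere in this diff (`mock_store` is illustrative):

```python
from unittest.mock import Mock

from tests.test_utils import make_awaitable

# make_awaitable returns a Future; unlike a bare coroutine, a Future can be
# awaited repeatedly, so a Mock may hand back the same awaitable every call.
mock_store = Mock()
mock_store.get_rooms_for_user = Mock(return_value=make_awaitable(["!room:test"]))


async def check() -> None:
    assert await mock_store.get_rooms_for_user("@alice:test") == ["!room:test"]
    # Awaiting the same Future a second time is fine.
    assert await mock_store.get_rooms_for_user("@alice:test") == ["!room:test"]
```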
diff --git a/tests/utils.py b/tests/utils.py
index 6d013e851..c06fc320f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -42,6 +42,10 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
+# When debugging a specific test, it's occasionally useful to write the
+# DB to disk and query it with the sqlite CLI.
+SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
+
# the dbname we will connect to in order to create the base database.
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"