diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml
new file mode 100644
index 000000000..a35fec394
--- /dev/null
+++ b/.buildkite/postgres-config.yaml
@@ -0,0 +1,21 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the script to connect to the postgresql database that will be available in the
+# CI's Docker setup at the point where this file is considered.
+server_name: "test"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+ name: "psycopg2"
+ args:
+ user: postgres
+ host: postgres
+ password: postgres
+ database: synapse
+
+# Suppress the key server warning.
+trusted_key_servers:
+ - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
new file mode 100755
index 000000000..df6082b0a
--- /dev/null
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from synapse.storage.engines import create_engine
+
+logger = logging.getLogger("create_postgres_db")
+
+if __name__ == "__main__":
+ # Create a PostgresEngine.
+ db_engine = create_engine({"name": "psycopg2", "args": {}})
+
+ # Connect to postgres to create the base database.
+ # We use "postgres" as a database because it's bound to exist and the "synapse" one
+ # doesn't exist yet.
+ db_conn = db_engine.module.connect(
+ user="postgres", host="postgres", password="postgres", dbname="postgres"
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+ cur.execute("CREATE DATABASE synapse;")
+ cur.close()
+ db_conn.close()
diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh
new file mode 100755
index 000000000..9ed217763
--- /dev/null
+++ b/.buildkite/scripts/test_synapse_port_db.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+#
+# Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along
+# with additional dependencies needed for the test (such as coverage or the PostgreSQL
+# driver), update the schema of the test SQLite database and run background updates on it,
+# create an empty test database in PostgreSQL, then run the 'synapse_port_db' script to
+# test porting the SQLite database to the PostgreSQL database (with coverage).
+
+set -xe
+cd `dirname $0`/../..
+
+echo "--- Install dependencies"
+
+# Install dependencies for this test.
+pip install psycopg2 coverage coverage-enable-subprocess
+
+# Install Synapse itself. This won't update any libraries.
+pip install -e .
+
+echo "--- Generate the signing key"
+
+# Generate the server's signing key.
+python -m synapse.app.homeserver --generate-keys -c .buildkite/sqlite-config.yaml
+
+echo "--- Prepare the databases"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml
+
+# Create the PostgreSQL database.
+./.buildkite/scripts/create_postgres_db.py
+
+echo "+++ Run synapse_port_db"
+
+# Run the script
+coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml
new file mode 100644
index 000000000..635b92176
--- /dev/null
+++ b/.buildkite/sqlite-config.yaml
@@ -0,0 +1,18 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the 'update_database' script to connect to the test SQLite database to upgrade its
+# schema and run background updates on it.
+server_name: "test"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+ name: "sqlite3"
+ args:
+ database: ".buildkite/test_db.db"
+
+# Suppress the key server warning.
+trusted_key_servers:
+ - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
new file mode 100644
index 000000000..f20567ba7
Binary files /dev/null and b/.buildkite/test_db.db differ
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 1ead0d003..8939fda67 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,3 +5,4 @@
* [ ] Pull request is based on the develop branch
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
+* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style))
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index a71a4a696..df81f6e54 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
+To facilitate meeting these criteria you can run ``scripts-dev/lint.sh``
+locally. Since this runs the tools listed in the above document, you'll need
+python 3.6 and to install each tool. **Note that the script does not just
+test/check, but also reformats code, so you may wish to ensure any new code is
+committed first**. By default this script checks all files and can take some
+time; if you alter only certain files, you might wish to specify paths as
+arguments to reduce the run-time.
+
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
+Before doing a commit, ensure the changes you've made don't produce
+linting errors. You can do this by running the linters as follows. Ensure to
+commit any files that were corrected.
+
+::
+ # Install the dependencies
+ pip install -U black flake8 isort
+
+ # Run the linter script
+ ./scripts-dev/lint.sh
+
Changelog
~~~~~~~~~
diff --git a/INSTALL.md b/INSTALL.md
index 69e423923..e7b429c05 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -413,16 +413,18 @@ For a more detailed guide to configuring your server for federation, see
## Email
-It is desirable for Synapse to have the capability to send email. For example,
-this is required to support the 'password reset' feature.
+It is desirable for Synapse to have the capability to send email. This allows
+Synapse to send password reset emails, send verifications when an email address
+is added to a user's account, and send email notifications to users when they
+receive new messages.
To configure an SMTP server for Synapse, modify the configuration section
-headed ``email``, and be sure to have at least the ``smtp_host``, ``smtp_port``
-and ``notif_from`` fields filled out. You may also need to set ``smtp_user``,
-``smtp_pass``, and ``require_transport_security``.
+headed `email`, and be sure to have at least the `smtp_host`, `smtp_port`
+and `notif_from` fields filled out. You may also need to set `smtp_user`,
+`smtp_pass`, and `require_transport_security`.
-If Synapse is not configured with an SMTP server, password reset via email will
- be disabled by default.
+If email is not configured, password reset, registration and notifications via
+email will be disabled.
## Registering a user
diff --git a/changelog.d/5727.feature b/changelog.d/5727.feature
new file mode 100644
index 000000000..819bebf2d
--- /dev/null
+++ b/changelog.d/5727.feature
@@ -0,0 +1 @@
+Add federation support for cross-signing.
diff --git a/changelog.d/6140.misc b/changelog.d/6140.misc
new file mode 100644
index 000000000..0feb97ec6
--- /dev/null
+++ b/changelog.d/6140.misc
@@ -0,0 +1 @@
+Add a CI job to test the `synapse_port_db` script.
\ No newline at end of file
diff --git a/changelog.d/6164.doc b/changelog.d/6164.doc
new file mode 100644
index 000000000..f9395b02b
--- /dev/null
+++ b/changelog.d/6164.doc
@@ -0,0 +1 @@
+Contributor documentation now mentions script to run linters.
diff --git a/changelog.d/6218.misc b/changelog.d/6218.misc
new file mode 100644
index 000000000..49d10c36c
--- /dev/null
+++ b/changelog.d/6218.misc
@@ -0,0 +1 @@
+Convert EventContext to an attrs.
diff --git a/changelog.d/6232.bugfix b/changelog.d/6232.bugfix
new file mode 100644
index 000000000..12718ba93
--- /dev/null
+++ b/changelog.d/6232.bugfix
@@ -0,0 +1 @@
+Remove a room from a server's public rooms list on room upgrade.
\ No newline at end of file
diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature
new file mode 100644
index 000000000..d225ac33b
--- /dev/null
+++ b/changelog.d/6238.feature
@@ -0,0 +1 @@
+Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
diff --git a/changelog.d/6240.misc b/changelog.d/6240.misc
new file mode 100644
index 000000000..0b3d7a14a
--- /dev/null
+++ b/changelog.d/6240.misc
@@ -0,0 +1 @@
+Move `persist_events` out from main data store.
diff --git a/changelog.d/6250.misc b/changelog.d/6250.misc
new file mode 100644
index 000000000..12e3fe66b
--- /dev/null
+++ b/changelog.d/6250.misc
@@ -0,0 +1 @@
+Reduce verbosity of user/room stats.
diff --git a/changelog.d/6251.misc b/changelog.d/6251.misc
new file mode 100644
index 000000000..371c6983b
--- /dev/null
+++ b/changelog.d/6251.misc
@@ -0,0 +1 @@
+Reduce impact of debug logging.
diff --git a/changelog.d/6253.bugfix b/changelog.d/6253.bugfix
new file mode 100644
index 000000000..266fae381
--- /dev/null
+++ b/changelog.d/6253.bugfix
@@ -0,0 +1 @@
+Delete keys from key backup when deleting backup versions.
diff --git a/changelog.d/6254.bugfix b/changelog.d/6254.bugfix
new file mode 100644
index 000000000..3181484b8
--- /dev/null
+++ b/changelog.d/6254.bugfix
@@ -0,0 +1 @@
+Make notification of cross-signing signatures work with workers.
diff --git a/changelog.d/6257.doc b/changelog.d/6257.doc
new file mode 100644
index 000000000..e985afde0
--- /dev/null
+++ b/changelog.d/6257.doc
@@ -0,0 +1 @@
+Modify CAPTCHA_SETUP.md to update the terms `private key` and `public key` to `secret key` and `site key` respectively. Contributed by Yash Jipkate.
diff --git a/changelog.d/6259.misc b/changelog.d/6259.misc
new file mode 100644
index 000000000..3ff81b1ac
--- /dev/null
+++ b/changelog.d/6259.misc
@@ -0,0 +1 @@
+Expose some homeserver functionality to spam checkers.
diff --git a/changelog.d/6263.misc b/changelog.d/6263.misc
new file mode 100644
index 000000000..7b1bb4b67
--- /dev/null
+++ b/changelog.d/6263.misc
@@ -0,0 +1 @@
+Change cache descriptors to always return deferreds.
diff --git a/changelog.d/6269.misc b/changelog.d/6269.misc
new file mode 100644
index 000000000..9fd333cc8
--- /dev/null
+++ b/changelog.d/6269.misc
@@ -0,0 +1 @@
+Fix incorrect comment regarding the functionality of an `if` statement.
\ No newline at end of file
diff --git a/changelog.d/6270.misc b/changelog.d/6270.misc
new file mode 100644
index 000000000..d1c581132
--- /dev/null
+++ b/changelog.d/6270.misc
@@ -0,0 +1 @@
+Update CI to run `isort` over the `scripts` and `scripts-dev` directories.
\ No newline at end of file
diff --git a/changelog.d/6271.misc b/changelog.d/6271.misc
new file mode 100644
index 000000000..236976027
--- /dev/null
+++ b/changelog.d/6271.misc
@@ -0,0 +1 @@
+Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.
\ No newline at end of file
diff --git a/changelog.d/6272.doc b/changelog.d/6272.doc
new file mode 100644
index 000000000..232180bcd
--- /dev/null
+++ b/changelog.d/6272.doc
@@ -0,0 +1 @@
+Update `INSTALL.md` Email section to talk about `account_threepid_delegates`.
\ No newline at end of file
diff --git a/changelog.d/6273.doc b/changelog.d/6273.doc
new file mode 100644
index 000000000..21a41d987
--- /dev/null
+++ b/changelog.d/6273.doc
@@ -0,0 +1 @@
+Fix a small typo in `account_threepid_delegates` configuration option.
\ No newline at end of file
diff --git a/changelog.d/6274.misc b/changelog.d/6274.misc
new file mode 100644
index 000000000..eb4966124
--- /dev/null
+++ b/changelog.d/6274.misc
@@ -0,0 +1 @@
+Port replication http server endpoints to async/await.
diff --git a/changelog.d/6275.misc b/changelog.d/6275.misc
new file mode 100644
index 000000000..f57e2c4ad
--- /dev/null
+++ b/changelog.d/6275.misc
@@ -0,0 +1 @@
+Port room rest handlers to async/await.
diff --git a/changelog.d/6276.misc b/changelog.d/6276.misc
new file mode 100644
index 000000000..4a4428251
--- /dev/null
+++ b/changelog.d/6276.misc
@@ -0,0 +1 @@
+Add a CI job to test the `synapse_port_db` script.
diff --git a/changelog.d/6277.misc b/changelog.d/6277.misc
new file mode 100644
index 000000000..490713577
--- /dev/null
+++ b/changelog.d/6277.misc
@@ -0,0 +1 @@
+Remove redundant CLI parameters on CI's `flake8` step.
\ No newline at end of file
diff --git a/changelog.d/6278.bugfix b/changelog.d/6278.bugfix
new file mode 100644
index 000000000..c10727046
--- /dev/null
+++ b/changelog.d/6278.bugfix
@@ -0,0 +1 @@
+Fix exception when remote servers attempt to join a room that they're not allowed to join.
diff --git a/changelog.d/6279.misc b/changelog.d/6279.misc
new file mode 100644
index 000000000..5f5144a9e
--- /dev/null
+++ b/changelog.d/6279.misc
@@ -0,0 +1 @@
+Port `federation_server.py` to async/await.
diff --git a/changelog.d/6280.misc b/changelog.d/6280.misc
new file mode 100644
index 000000000..96a0eb21b
--- /dev/null
+++ b/changelog.d/6280.misc
@@ -0,0 +1 @@
+Port receipt and read markers to async/wait.
diff --git a/changelog.d/6284.bugfix b/changelog.d/6284.bugfix
new file mode 100644
index 000000000..cf15053d2
--- /dev/null
+++ b/changelog.d/6284.bugfix
@@ -0,0 +1 @@
+Prevent errors from appearing on Synapse startup if `git` is not installed.
\ No newline at end of file
diff --git a/changelog.d/6291.misc b/changelog.d/6291.misc
new file mode 100644
index 000000000..7b1bb4b67
--- /dev/null
+++ b/changelog.d/6291.misc
@@ -0,0 +1 @@
+Change cache descriptors to always return deferreds.
diff --git a/changelog.d/6294.misc b/changelog.d/6294.misc
new file mode 100644
index 000000000..a3e6b8296
--- /dev/null
+++ b/changelog.d/6294.misc
@@ -0,0 +1 @@
+Split out state storage into separate data store.
diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc
new file mode 100644
index 000000000..d4190730b
--- /dev/null
+++ b/changelog.d/6298.misc
@@ -0,0 +1 @@
+Refactor EventContext for clarity.
\ No newline at end of file
diff --git a/changelog.d/6300.misc b/changelog.d/6300.misc
new file mode 100644
index 000000000..0b3d7a14a
--- /dev/null
+++ b/changelog.d/6300.misc
@@ -0,0 +1 @@
+Move `persist_events` out from main data store.
diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature
new file mode 100644
index 000000000..78a187a1d
--- /dev/null
+++ b/changelog.d/6301.feature
@@ -0,0 +1 @@
+Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
diff --git a/changelog.d/6304.misc b/changelog.d/6304.misc
new file mode 100644
index 000000000..20372b4f7
--- /dev/null
+++ b/changelog.d/6304.misc
@@ -0,0 +1 @@
+Update the version of black used to 19.10b0.
diff --git a/changelog.d/6305.misc b/changelog.d/6305.misc
new file mode 100644
index 000000000..f047fc306
--- /dev/null
+++ b/changelog.d/6305.misc
@@ -0,0 +1 @@
+Add some documentation about worker replication.
diff --git a/changelog.d/6306.bugfix b/changelog.d/6306.bugfix
new file mode 100644
index 000000000..c7dcbcdce
--- /dev/null
+++ b/changelog.d/6306.bugfix
@@ -0,0 +1 @@
+Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash.
diff --git a/changelog.d/6307.bugfix b/changelog.d/6307.bugfix
new file mode 100644
index 000000000..f2917c505
--- /dev/null
+++ b/changelog.d/6307.bugfix
@@ -0,0 +1 @@
+Fix `/purge_room` admin API.
diff --git a/changelog.d/6312.misc b/changelog.d/6312.misc
new file mode 100644
index 000000000..55e3e1654
--- /dev/null
+++ b/changelog.d/6312.misc
@@ -0,0 +1 @@
+Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only.
diff --git a/changelog.d/6313.bugfix b/changelog.d/6313.bugfix
new file mode 100644
index 000000000..f4d4a97f0
--- /dev/null
+++ b/changelog.d/6313.bugfix
@@ -0,0 +1 @@
+Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0.
diff --git a/changelog.d/6314.misc b/changelog.d/6314.misc
new file mode 100644
index 000000000..236976027
--- /dev/null
+++ b/changelog.d/6314.misc
@@ -0,0 +1 @@
+Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.
\ No newline at end of file
diff --git a/changelog.d/6318.misc b/changelog.d/6318.misc
new file mode 100644
index 000000000..63527ccef
--- /dev/null
+++ b/changelog.d/6318.misc
@@ -0,0 +1 @@
+Remove the dependency on psutil and replace functionality with the stdlib `resource` module.
diff --git a/changelog.d/6319.misc b/changelog.d/6319.misc
new file mode 100644
index 000000000..9711ef21e
--- /dev/null
+++ b/changelog.d/6319.misc
@@ -0,0 +1 @@
+Improve documentation for EventContext fields.
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 6b22400a6..3bbbcfa1b 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -78,7 +78,7 @@ class InputOutput(object):
m = re.match("^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("%s joining %s" % (self.user, room_name))
self.server.join_room(room_name, self.user, self.user)
# self.print_line("OK.")
@@ -105,7 +105,7 @@ class InputOutput(object):
m = re.match("^backfill (\S+)$", line)
if m:
# we want to backfill a room
- room_name, = m.groups()
+ (room_name,) = m.groups()
self.print_line("backfill %s" % room_name)
self.server.backfill(room_name)
return
diff --git a/docker/README.md b/docker/README.md
index 4b712f3f5..24dfa77dc 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -101,7 +101,7 @@ is suitable for local testing, but for any practical use, you will either need
to use a reverse proxy, or configure Synapse to expose an HTTPS port.
For documentation on using a reverse proxy, see
-https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
+https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
For more information on enabling TLS support in synapse itself, see
https://github.com/matrix-org/synapse/blob/master/INSTALL.md#tls-certificates. Of
diff --git a/docker/start.py b/docker/start.py
index e41ea20e7..6e1cb807a 100755
--- a/docker/start.py
+++ b/docker/start.py
@@ -217,8 +217,9 @@ def main(args, environ):
# backwards-compatibility generate-a-config-on-the-fly mode
if "SYNAPSE_CONFIG_PATH" in environ:
error(
- "SYNAPSE_SERVER_NAME and SYNAPSE_CONFIG_PATH are mutually exclusive "
- "except in `generate` or `migrate_config` mode."
+ "SYNAPSE_SERVER_NAME can only be combined with SYNAPSE_CONFIG_PATH "
+ "in `generate` or `migrate_config` mode. To start synapse using a "
+ "config file, unset the SYNAPSE_SERVER_NAME environment variable."
)
config_path = "/compiled/homeserver.yaml"
diff --git a/docs/CAPTCHA_SETUP.md b/docs/CAPTCHA_SETUP.md
index 5f9057530..331e5d059 100644
--- a/docs/CAPTCHA_SETUP.md
+++ b/docs/CAPTCHA_SETUP.md
@@ -4,7 +4,7 @@ The captcha mechanism used is Google's ReCaptcha. This requires API keys from Go
## Getting keys
-Requires a public/private key pair from:
+Requires a site/secret key pair from:
@@ -15,8 +15,8 @@ Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option
The keys are a config option on the home server config. If they are not
visible, you can generate them via `--generate-config`. Set the following value:
- recaptcha_public_key: YOUR_PUBLIC_KEY
- recaptcha_private_key: YOUR_PRIVATE_KEY
+ recaptcha_public_key: YOUR_SITE_KEY
+ recaptcha_private_key: YOUR_SECRET_KEY
In addition, you MUST enable captchas via:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 6c81c0db7..d2f4aff82 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -955,7 +955,7 @@ uploads_path: "DATADIR/uploads"
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
- #email: https://example.com # Delegate email sending to example.org
+ #email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index e099d8a87..ba9e874d0 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -199,7 +199,20 @@ client (C):
#### REPLICATE (C)
- Asks the server to replicate a given stream
+Asks the server to replicate a given stream. The syntax is:
+
+```
+ REPLICATE
+```
+
+Where `` may be either:
+ * a numeric stream_id to stream updates since (exclusive)
+ * `NOW` to stream all subsequent updates.
+
+The `` is the name of a replication stream to subscribe
+to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
+of streams). It can also be `ALL` to subscribe to all known streams,
+in which case the `` must be set to `NOW`.
#### USER_SYNC (C)
diff --git a/mypy.ini b/mypy.ini
index ffadaddc0..1d77c0ecc 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,8 +1,11 @@
[mypy]
-namespace_packages=True
-plugins=mypy_zope:plugin
-follow_imports=skip
-mypy_path=stubs
+namespace_packages = True
+plugins = mypy_zope:plugin
+follow_imports = normal
+check_untyped_defs = True
+show_error_codes = True
+show_traceback = True
+mypy_path = stubs
[mypy-zope]
ignore_missing_imports = True
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 02a2ca39e..34c4854e1 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -7,7 +7,15 @@
set -e
-isort -y -rc synapse tests scripts-dev scripts
-flake8 synapse tests
-python3 -m black synapse tests scripts-dev scripts
+if [ $# -ge 1 ]
+then
+ files=$*
+else
+ files="synapse tests scripts-dev scripts"
+fi
+
+echo "Linting these locations: $files"
+isort -y -rc $files
+flake8 $files
+python3 -m black $files
./scripts-dev/config-lint.sh
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
new file mode 100755
index 000000000..27a1ad1e7
--- /dev/null
+++ b/scripts-dev/update_database
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import sys
+
+import yaml
+
+from twisted.internet import defer, reactor
+
+from synapse.config.homeserver import HomeServerConfig
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger("update_database")
+
+
+class MockHomeserver(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+ def __init__(self, config, database_engine, db_conn, **kwargs):
+ super(MockHomeserver, self).__init__(
+ config.server_name,
+ reactor=reactor,
+ config=config,
+ database_engine=database_engine,
+ **kwargs
+ )
+
+ self.database_engine = database_engine
+ self.db_conn = db_conn
+
+ def get_db_conn(self):
+ return self.db_conn
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=(
+ "Updates a synapse database to the latest schema and runs background updates"
+ " on it."
+ )
+ )
+ parser.add_argument("-v", action='store_true')
+ parser.add_argument(
+ "--database-config",
+ type=argparse.FileType('r'),
+ required=True,
+ help="A database config file for either a SQLite3 database or a PostgreSQL one.",
+ )
+
+ args = parser.parse_args()
+
+ logging_config = {
+ "level": logging.DEBUG if args.v else logging.INFO,
+ "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ }
+
+ logging.basicConfig(**logging_config)
+
+ # Load, process and sanity-check the config.
+ hs_config = yaml.safe_load(args.database_config)
+
+ if "database" not in hs_config:
+ sys.stderr.write("The configuration file must have a 'database' section.\n")
+ sys.exit(4)
+
+ config = HomeServerConfig()
+ config.parse_config_dict(hs_config, "", "")
+
+ # Create the database engine and a connection to it.
+ database_engine = create_engine(config.database_config)
+ db_conn = database_engine.module.connect(
+ **{
+ k: v
+ for k, v in config.database_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ )
+
+ # Update the database to the latest schema.
+ prepare_database(db_conn, database_engine, config=config)
+ db_conn.commit()
+
+ # Instantiate and initialise the homeserver object.
+ hs = MockHomeserver(
+ config,
+ database_engine,
+ db_conn,
+ db_config=config.database_config,
+ )
+ # setup instantiates the store within the homeserver object.
+ hs.setup()
+ store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def run_background_updates():
+ yield store.run_background_updates(sleep=False)
+ # Stop the reactor to exit the script once every background update is run.
+ reactor.stop()
+
+ # Apply all background updates on the database.
+ reactor.callWhenRunning(lambda: run_as_background_process(
+ "background_updates", run_background_updates
+ ))
+
+ reactor.run()
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index 12747c602..b5b63933a 100755
--- a/scripts/move_remote_media_to_new_store.py
+++ b/scripts/move_remote_media_to_new_store.py
@@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
# check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file):
- logger.warn(
+ logger.warning(
"Original for %s/%s (%s) does not exist",
origin_server,
file_id,
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 54faed1e8..0d3321682 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -157,7 +157,7 @@ class Store(
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
- logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N:
i += 1
conn.rollback()
@@ -432,7 +432,7 @@ class Porter(object):
for row in rows:
d = dict(zip(headers, row))
if "\0" in d['value']:
- logger.warn('dropping search row %s', d)
+ logger.warning('dropping search row %s', d)
else:
rows_dict.append(d)
@@ -647,7 +647,7 @@ class Porter(object):
if isinstance(col, bytes):
return bytearray(col)
elif isinstance(col, string_types) and "\0" in col:
- logger.warn(
+ logger.warning(
"DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 53f3bb0fa..5d0b7d280 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -497,7 +497,7 @@ class Auth(object):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
- logger.warn("Unrecognised appservice access token.")
+ logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 312196675..49c4b8505 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -138,3 +138,10 @@ class LimitBlockingTypes(object):
MONTHLY_ACTIVE_USER = "monthly_active_user"
HS_DISABLED = "hs_disabled"
+
+
+class EventContentFields(object):
+ """Fields found in events' content, regardless of type."""
+
+ # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
+ LABELS = "org.matrix.labels"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9f06556bd..bec13f08d 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -20,6 +20,7 @@ from jsonschema import FormatChecker
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import RoomID, UserID
@@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
+ # Include or exclude events with the provided labels.
+ # cf https://github.com/matrix-org/matrix-doc/pull/2326
+ "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
+ "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
},
}
@@ -259,6 +264,9 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None)
+ self.labels = self.filter_json.get("org.matrix.labels", None)
+ self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+
def filters_all_types(self):
return "*" in self.not_types
@@ -282,6 +290,7 @@ class Filter(object):
room_id = None
ev_type = "m.presence"
contains_url = False
+ labels = []
else:
sender = event.get("sender", None)
if not sender:
@@ -300,10 +309,11 @@ class Filter(object):
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
+ labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(room_id, sender, ev_type, contains_url)
+ return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, contains_url):
+ def check_fields(self, room_id, sender, event_type, labels, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -313,6 +323,7 @@ class Filter(object):
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v),
+ "labels": lambda v: v in labels,
}
for name, match_func in literal_keys.items():
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index d877c7783..a01bac299 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses):
bind_addresses (list): Addresses on which the service listens.
"""
if address == "0.0.0.0" and "::" in bind_addresses:
- logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]")
+ logger.warning(
+ "Failed to listen on 0.0.0.0, continuing because listening on [::]"
+ )
else:
raise e
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 767b87d2d..02b900f38 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -94,7 +94,7 @@ class AppserviceServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -103,7 +103,7 @@ class AppserviceServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dbcc414c4..dadb487d5 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index f20d810ec..d110599a3 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 1ef027a88..418c08625 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 04fbb407a..139221ad3 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 9504bfbc7..e647459d0 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index eb54f5685..00a7f8330 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,12 +19,13 @@ from __future__ import print_function
import gc
import logging
+import math
import os
+import resource
import sys
from six import iteritems
-import psutil
from prometheus_client import Gauge
from twisted.application import service
@@ -282,7 +283,7 @@ class SynapseHomeServer(HomeServer):
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -291,7 +292,7 @@ class SynapseHomeServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
@@ -471,6 +472,87 @@ class SynapseService(service.Service):
return self._port.stopListening()
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+
+@defer.inlineCallbacks
+def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = yield hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ for name, count in iteritems(daily_user_type_results):
+ stats["daily_user_type_" + name] = count
+
+ room_count = yield hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in iteritems(r30_results):
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ #
+ # Performance statistics
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # Database version
+ #
+
+ stats["database_engine"] = hs.get_datastore().database_engine_name
+ stats["database_server_version"] = hs.get_datastore().get_server_version()
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ yield hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
@@ -497,91 +579,19 @@ def run(hs):
reactor.run = profile(reactor.run)
clock = hs.get_clock()
- start_time = clock.time()
stats = {}
- # Contains the list of processes we will be monitoring
- # currently either 0 or 1
- stats_process = []
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)))
+ )
def start_phone_stats_home():
- return run_as_background_process("phone_stats_home", phone_stats_home)
-
- @defer.inlineCallbacks
- def phone_stats_home():
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - start_time)
- if uptime < 0:
- uptime = 0
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
+ return run_as_background_process(
+ "phone_stats_home", phone_stats_home, hs, stats
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
-
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
- stats["daily_user_type_" + name] = count
-
- room_count = yield hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
- stats[
- "daily_active_rooms"
- ] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
-
- r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
- stats["r30_users_" + name] = count
-
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = CACHE_SIZE_FACTOR
- stats["event_cache_size"] = hs.config.event_cache_size
-
- if len(stats_process) > 0:
- stats["memory_rss"] = 0
- stats["cpu_average"] = 0
- for process in stats_process:
- stats["memory_rss"] += process.memory_info().rss
- stats["cpu_average"] += int(process.cpu_percent(interval=None))
-
- stats["database_engine"] = hs.get_datastore().database_engine_name
- stats["database_server_version"] = hs.get_datastore().get_server_version()
- logger.info(
- "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
- )
- try:
- yield hs.get_simple_http_client().put_json(
- hs.config.report_stats_endpoint, stats
- )
- except Exception as e:
- logger.warn("Error reporting stats: %s", e)
-
- def performance_stats_init():
- try:
- process = psutil.Process()
- # Ensure we can fetch both, and make the initial request for cpu_percent
- # so the next request will use this as the initial point.
- process.memory_info().rss
- process.cpu_percent(interval=None)
- logger.info("report_stats can use psutil")
- stats_process.append(process)
- except (AttributeError):
- logger.warning("Unable to read memory/cpu stats. Disabling reporting.")
def generate_user_daily_visit_stats():
return run_as_background_process(
@@ -626,7 +636,7 @@ def run(hs):
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
- clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
+ clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
# We need to defer this init for the cases that we daemonize
# otherwise the process ID we get is that of the non-daemon process
@@ -634,7 +644,7 @@ def run(hs):
# We wait 5 minutes to send the first set of stats as the server can
# be quite busy the first few minutes
- clock.call_later(5 * 60, start_phone_stats_home)
+ clock.call_later(5 * 60, start_phone_stats_home, hs, stats)
_base.start_reactor(
"synapse-homeserver",
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 6bc7202f3..2c6dd3ef0 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index d84732ee3..01a5ffc36 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -114,7 +114,7 @@ class PusherServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -123,7 +123,7 @@ class PusherServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 6a7e2fa70..b14da09f4 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index a5d6dc791..6cb100319 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer):
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(
+ logger.warning(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
@@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer):
else:
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
+ logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 33b357942..aea3985a5 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -94,7 +94,9 @@ class ApplicationService(object):
ip_range_whitelist=None,
):
self.token = token
- self.url = url
+ self.url = (
+ url.rstrip("/") if isinstance(url, str) else None
+ ) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ec5d430af..52ff1b262 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -125,7 +125,7 @@ class KeyConfig(Config):
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
- logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
+ logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
key_servers = [{"server_name": "matrix.org"}]
else:
key_servers = config.get("trusted_key_servers", [])
@@ -156,7 +156,7 @@ class KeyConfig(Config):
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
- logger.warn("Config is missing macaroon_secret_key")
+ logger.warning("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index be92e33f9..75bb90471 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("")
if not log_config:
- logger.warn("Reloaded a blank config?")
+ logger.warning("Reloaded a blank config?")
logging.config.dictConfig(log_config)
@@ -234,8 +234,8 @@ def setup_logging(
# make sure that the first thing we log is a thing we can grep backwards
# for
- logging.warn("***** STARTING SERVER *****")
- logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
+ logging.warning("***** STARTING SERVER *****")
+ logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
return logger
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ab41623b2..1f6dac69d 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -300,7 +300,7 @@ class RegistrationConfig(Config):
# If a delegate is specified, the config option public_baseurl must also be filled out.
#
account_threepid_delegates:
- #email: https://example.com # Delegate email sending to example.org
+ #email: https://example.com # Delegate email sending to example.com
#msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 694fb2c81..ccaa8a992 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -125,9 +125,11 @@ def compute_event_signature(event_dict, signature_name, signing_key):
redact_json = prune_event_dict(event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
- logger.debug("Signing event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key)
- logger.debug("Signed event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"]
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e7b722547..ec3243b27 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
- logger.warn("Trusting event: %s", event.event_id)
+ logger.warning("Trusting event: %s", event.event_id)
return
if event.type == EventTypes.Create:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index acbcbeece..5f07f6fe4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,104 +12,103 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
from six import iteritems
+import attr
from frozendict import frozendict
from twisted.internet import defer
+from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background
-class EventContext(object):
+@attr.s(slots=True)
+class EventContext:
"""
+ Holds information relevant to persisting an event
+
Attributes:
- state_group (int|None): state group id, if the state has been stored
- as a state group. This is usually only None if e.g. the event is
- an outlier.
- rejected (bool|str): A rejection reason if the event was rejected, else
- False
+ rejected: A rejection reason if the event was rejected, else False
- push_actions (list[(str, list[object])]): list of (user_id, actions)
- tuples
+ state_group: The ID of the state group for this event. Note that state events
+ are persisted with a state group which includes the new event, so this is
+ effectively the state *after* the event in question.
- prev_group (int): Previously persisted state group. ``None`` for an
- outlier.
- delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
- (type, state_key) -> event_id. ``None`` for an outlier.
+ For a *rejected* state event, where the state of the rejected event is
+ ignored, this state_group should never make it into the
+ event_to_state_groups table. Indeed, inspecting this value for a rejected
+ state event is almost certainly incorrect.
- prev_state_events (?): XXX: is this ever set to anything other than
- the empty list?
+ For an outlier, where we don't have the state at the event, this will be
+ None.
+
+ prev_group: If it is known, ``state_group``'s prev_group. Note that this being
+ None does not necessarily mean that ``state_group`` does not have
+ a prev_group!
+
+ If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
+ will always also be ``None``.
+
+ Note that this *not* (necessarily) the state group associated with
+ ``_prev_state_ids``.
+
+ delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
+ and ``state_group``.
+
+ app_service: If this event is being sent by a (local) application service, that
+ app service.
+
+ _current_state_ids: The room state map, including this event - ie, the state
+ in ``state_group``.
- _current_state_ids (dict[(str, str), str]|None):
- The current state map including the current event. None if outlier
- or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
- _prev_state_ids (dict[(str, str), str]|None):
- The current state map excluding the current event. None if outlier
- or we haven't fetched the state from DB yet.
+ FIXME: what is this for an outlier? it seems ill-defined. It seems like
+ it could be either {}, or the state we were given by the remote
+ server, depending on $THINGS
+
+ Note that this is a private attribute: it should be accessed via
+ ``get_current_state_ids``. _AsyncEventContext impl calculates this
+ on-demand: it will be None until that happens.
+
+ _prev_state_ids: The room state map, excluding this event. For a non-state
+ event, this will be the same as _current_state_events.
+
+ Note that it is a completely different thing to prev_group!
+
(type, state_key) -> event_id
- _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
- been calculated. None if we haven't started calculating yet
+ FIXME: again, what is this for an outlier?
- _event_type (str): The type of the event the context is associated with.
- Only set when state has not been fetched yet.
-
- _event_state_key (str|None): The state_key of the event the context is
- associated with. Only set when state has not been fetched yet.
-
- _prev_state_id (str|None): If the event associated with the context is
- a state event, then `_prev_state_id` is the event_id of the state
- that was replaced.
- Only set when state has not been fetched yet.
+ As with _current_state_ids, this is a private attribute. It should be
+ accessed via get_prev_state_ids.
"""
- __slots__ = [
- "state_group",
- "rejected",
- "prev_group",
- "delta_ids",
- "prev_state_events",
- "app_service",
- "_current_state_ids",
- "_prev_state_ids",
- "_prev_state_id",
- "_event_type",
- "_event_state_key",
- "_fetching_state_deferred",
- ]
+ rejected = attr.ib(default=False, type=Union[bool, str])
+ state_group = attr.ib(default=None, type=Optional[int])
+ prev_group = attr.ib(default=None, type=Optional[int])
+ delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ app_service = attr.ib(default=None, type=Optional[ApplicationService])
- def __init__(self):
- self.prev_state_events = []
- self.rejected = False
- self.app_service = None
+ _current_state_ids = attr.ib(
+ default=None, type=Optional[Dict[Tuple[str, str], str]]
+ )
+ _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
@staticmethod
def with_state(
state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
):
- context = EventContext()
-
- # The current state including the current event
- context._current_state_ids = current_state_ids
- # The current state excluding the current event
- context._prev_state_ids = prev_state_ids
- context.state_group = state_group
-
- context._prev_state_id = None
- context._event_type = None
- context._event_state_key = None
- context._fetching_state_deferred = defer.succeed(None)
-
- # A previously persisted state group and a delta between that
- # and this state.
- context.prev_group = prev_group
- context.delta_ids = delta_ids
-
- return context
+ return EventContext(
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ state_group=state_group,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
@defer.inlineCallbacks
def serialize(self, event, store):
@@ -141,7 +140,6 @@ class EventContext(object):
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}
@@ -157,24 +155,17 @@ class EventContext(object):
Returns:
EventContext
"""
- context = EventContext()
-
- # We use the state_group and prev_state_id stuff to pull the
- # current_state_ids out of the DB and construct prev_state_ids.
- context._prev_state_id = input["prev_state_id"]
- context._event_type = input["event_type"]
- context._event_state_key = input["event_state_key"]
-
- context._current_state_ids = None
- context._prev_state_ids = None
- context._fetching_state_deferred = None
-
- context.state_group = input["state_group"]
- context.prev_group = input["prev_group"]
- context.delta_ids = _decode_state_dict(input["delta_ids"])
-
- context.rejected = input["rejected"]
- context.prev_state_events = input["prev_state_events"]
+ context = _AsyncEventContextImpl(
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ prev_state_id=input["prev_state_id"],
+ event_type=input["event_type"],
+ event_state_key=input["event_state_key"],
+ state_group=input["state_group"],
+ prev_group=input["prev_group"],
+ delta_ids=_decode_state_dict(input["delta_ids"]),
+ rejected=input["rejected"],
+ )
app_service_id = input["app_service_id"]
if app_service_id:
@@ -192,14 +183,7 @@ class EventContext(object):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._current_state_ids
@defer.inlineCallbacks
@@ -212,14 +196,7 @@ class EventContext(object):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
+ yield self._ensure_fetched(store)
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -233,6 +210,44 @@ class EventContext(object):
return self._current_state_ids
+ def _ensure_fetched(self, store):
+ return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+ """
+ An implementation of EventContext which fetches _current_state_ids and
+ _prev_state_ids from the database on demand.
+
+ Attributes:
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+
+ _event_state_key (str): The state_key of the event the context is
+ associated with.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ """
+
+ _prev_state_id = attr.ib(default=None)
+ _event_type = attr.ib(default=None)
+ _event_state_key = attr.ib(default=None)
+ _fetching_state_deferred = attr.ib(default=None)
+
+ def _ensure_fetched(self, store):
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(
+ self._fill_out_state, store
+ )
+
+ return make_deferred_yieldable(self._fetching_state_deferred)
+
@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
@@ -250,27 +265,6 @@ class EventContext(object):
else:
self._prev_state_ids = self._current_state_ids
- @defer.inlineCallbacks
- def update_state(
- self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
- ):
- """Replace the state in the context
- """
-
- # We need to make sure we wait for any ongoing fetching of state
- # to complete so that the updated state doesn't get clobbered
- if self._fetching_state_deferred:
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- self.state_group = state_group
- self._prev_state_ids = prev_state_ids
- self.prev_group = prev_group
- self._current_state_ids = current_state_ids
- self.delta_ids = delta_ids
-
- # We need to ensure that that we've marked as having fetched the state
- self._fetching_state_deferred = defer.succeed(None)
-
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 129771f18..5a907718d 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+
+from synapse.spam_checker_api import SpamCheckerApi
+
class SpamChecker(object):
def __init__(self, hs):
@@ -26,7 +31,14 @@ class SpamChecker(object):
pass
if module is not None:
- self.spam_checker = module(config=config)
+ # Older spam checkers don't accept the `api` argument, so we
+ # try and detect support.
+ spam_args = inspect.getfullargspec(module)
+ if "api" in spam_args.args:
+ api = SpamCheckerApi(hs)
+ self.spam_checker = module(config=config, api=api)
+ else:
+ self.spam_checker = module(config=config)
def check_event_for_spam(self, event):
"""Checks if a given event is considered "spammy" by this server.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 223aace0d..0e2218328 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -102,7 +102,7 @@ class FederationBase(object):
pass
if not res:
- logger.warn(
+ logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
@@ -173,7 +173,7 @@ class FederationBase(object):
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
- logger.warn(
+ logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
pdu.get_pdu_json(),
@@ -185,7 +185,7 @@ class FederationBase(object):
def errback(failure, pdu):
failure.trap(SynapseError)
with PreserveLoggingContext(ctx):
- logger.warn(
+ logger.warning(
"Signature check failed for %s: %s",
pdu.event_id,
failure.getErrorMessage(),
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 5b22a39b7..545d71965 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -196,7 +196,7 @@ class FederationClient(FederationBase):
dest, room_id, extremities, limit
)
- logger.debug("backfill transaction_data=%s", repr(transaction_data))
+ logger.debug("backfill transaction_data=%r", transaction_data)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
@@ -522,12 +522,12 @@ class FederationClient(FederationBase):
res = yield callback(destination)
return res
except InvalidResponseError as e:
- logger.warn("Failed to %s via %s: %s", description, destination, e)
+ logger.warning("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
- logger.warn(
+ logger.warning(
"Failed to %s via %s: %i %s",
description,
destination,
@@ -535,7 +535,9 @@ class FederationClient(FederationBase):
e.args[0],
)
except Exception:
- logger.warn("Failed to %s via %s", description, destination, exc_info=1)
+ logger.warning(
+ "Failed to %s via %s", description, destination, exc_info=1
+ )
raise SynapseError(502, "Failed to %s via any server" % (description,))
@@ -553,7 +555,7 @@ class FederationClient(FederationBase):
Note that this does not append any events to any graphs.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations (Iterable[str]): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5fc7c1d67..d942d77a7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -21,7 +21,6 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -86,14 +85,12 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_backfill_request(self, origin, room_id, versions, limit):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- pdus = yield self.handler.on_backfill_request(
+ pdus = await self.handler.on_backfill_request(
origin, room_id, versions, limit
)
@@ -101,9 +98,7 @@ class FederationServer(FederationBase):
return 200, res
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, origin, transaction_data):
+ async def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -118,18 +113,17 @@ class FederationServer(FederationBase):
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (
- yield self._transaction_linearizer.queue(
+ await self._transaction_linearizer.queue(
(origin, transaction.transaction_id)
)
):
- result = yield self._handle_incoming_transaction(
+ result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
- @defer.inlineCallbacks
- def _handle_incoming_transaction(self, origin, transaction, request_time):
+ async def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
@@ -140,7 +134,7 @@ class FederationServer(FederationBase):
Returns:
Deferred[(int, object)]: http response code and body
"""
- response = yield self.transaction_actions.have_responded(origin, transaction)
+ response = await self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
@@ -151,7 +145,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- # Reject if PDU count > 50 and EDU count > 100
+ # Reject if PDU count > 50 or EDU count > 100
if len(transaction.pdus) > 50 or (
hasattr(transaction, "edus") and len(transaction.edus) > 100
):
@@ -159,7 +153,7 @@ class FederationServer(FederationBase):
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
- yield self.transaction_actions.set_response(
+ await self.transaction_actions.set_response(
origin, transaction, 400, response
)
return 400, response
@@ -195,7 +189,7 @@ class FederationServer(FederationBase):
continue
try:
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
@@ -221,13 +215,12 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
- @defer.inlineCallbacks
- def process_pdus_for_room(room_id):
+ async def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
try:
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
- logger.warn("Ignoring PDUs for room %s from banned server", room_id)
+ logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@@ -237,10 +230,10 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
- yield self._handle_received_pdu(origin, pdu)
+ await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
- logger.warn("Error handling PDU %s: %s", event_id, e)
+ logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
@@ -251,36 +244,33 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- yield concurrently_execute(
+ await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
- yield self.received_edu(origin, edu.edu_type, edu.content)
+ await self.received_edu(origin, edu.edu_type, edu.content)
response = {"pdus": pdu_results}
logger.debug("Returning: %s", str(response))
- yield self.transaction_actions.set_response(origin, transaction, 200, response)
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
return 200, response
- @defer.inlineCallbacks
- def received_edu(self, origin, edu_type, content):
+ async def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
- yield self.registry.on_edu(edu_type, origin, content)
+ await self.registry.on_edu(edu_type, origin, content)
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
+ async def on_context_state_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -289,8 +279,8 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.wrap(
+ with (await self._server_linearizer.queue((origin, room_id))):
+ resp = await self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
room_id,
@@ -299,65 +289,60 @@ class FederationServer(FederationBase):
return 200, resp
- @defer.inlineCallbacks
- def on_state_ids_request(self, origin, room_id, event_id):
+ async def on_state_ids_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+ state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
- @defer.inlineCallbacks
- def _on_context_state_request_compute(self, room_id, event_id):
- pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
- auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ async def _on_context_state_request_compute(self, room_id, event_id):
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self.handler.get_persisted_pdu(origin, event_id)
+ async def on_pdu_request(self, origin, event_id):
+ pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
return 200, self._transaction_from_pdus([pdu]).get_dict()
else:
return 404, ""
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
+ async def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
- resp = yield self.registry.on_query(query_type, args)
+ resp = await self.registry.on_query(query_type, args)
return 200, resp
- @defer.inlineCallbacks
- def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
if room_version not in supported_versions:
- logger.warn("Room version %s not in %s", room_version, supported_versions)
+ logger.warning(
+ "Room version %s not in %s", room_version, supported_versions
+ )
raise IncompatibleRoomVersionError(room_version=room_version)
- pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
+ pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content, room_version):
+ async def on_invite_request(self, origin, content, room_version):
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
400,
@@ -369,28 +354,27 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+ ret_pdu = await self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content, room_id):
+ async def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+ res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return (
200,
@@ -402,48 +386,44 @@ class FederationServer(FederationBase):
},
)
- @defer.inlineCallbacks
- def on_make_leave_request(self, origin, room_id, user_id):
+ async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
- pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
+ await self.check_server_matches_acl(origin_host, room_id)
+ pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_send_leave_request(self, origin, content, room_id):
+ async def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
- yield self.handler.on_send_leave_request(origin, pdu)
+ await self.handler.on_send_leave_request(origin, pdu)
return 200, {}
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_event_auth(self, origin, room_id, event_id):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
+ auth_pdus = await self.handler.on_event_auth(event_id)
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
- @defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, room_id, event_id):
+ async def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
@@ -462,22 +442,22 @@ class FederationServer(FederationBase):
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
- with (yield self._server_linearizer.queue((origin, room_id))):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
]
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
+ signed_auth = await self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version
)
- ret = yield self.handler.on_query_auth(
+ ret = await self.handler.on_query_auth(
origin,
event_id,
room_id,
@@ -503,16 +483,14 @@ class FederationServer(FederationBase):
return self.on_query_request("user_devices", user_id)
@trace
- @defer.inlineCallbacks
- @log_function
- def on_claim_client_keys(self, origin, content):
+ async def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
- results = yield self.store.claim_e2e_one_time_keys(query)
+ results = await self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
@@ -536,14 +514,12 @@ class FederationServer(FederationBase):
return {"one_time_keys": json_result}
- @defer.inlineCallbacks
- @log_function
- def on_get_missing_events(
+ async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit
):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
@@ -553,7 +529,7 @@ class FederationServer(FederationBase):
limit,
)
- missing_events = yield self.handler.on_get_missing_events(
+ missing_events = await self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit
)
@@ -586,8 +562,7 @@ class FederationServer(FederationBase):
destination=None,
)
- @defer.inlineCallbacks
- def _handle_received_pdu(self, origin, pdu):
+ async def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
@@ -640,37 +615,34 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = yield self.store.get_room_version(pdu.room_id)
+ room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
- yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
+ await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "" % self.server_name
- @defer.inlineCallbacks
- def exchange_third_party_invite(
+ async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
- ret = yield self.handler.exchange_third_party_invite(
+ ret = await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
return ret
- @defer.inlineCallbacks
- def on_exchange_third_party_invite_request(self, room_id, event_dict):
- ret = yield self.handler.on_exchange_third_party_invite_request(
+ async def on_exchange_third_party_invite_request(self, room_id, event_dict):
+ ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
return ret
- @defer.inlineCallbacks
- def check_server_matches_acl(self, server_name, room_id):
+ async def check_server_matches_acl(self, server_name, room_id):
"""Check if the given server is allowed by the server ACLs in the room
Args:
@@ -680,13 +652,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
- state_ids = yield self.store.get_current_state_ids(room_id)
+ state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
- acl_event = yield self.store.get_event(acl_event_id)
+ acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return
@@ -709,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event):
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warn("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -723,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event):
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warn("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignorning non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -733,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event):
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warn("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignorning non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
@@ -747,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
- logger.warn(
+ logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
@@ -799,15 +771,14 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
- logger.warn("No handler registered for EDU type %s", edu_type)
+ logger.warning("No handler registered for EDU type %s", edu_type)
with start_active_span_from_edu(content, "handle_edu"):
try:
- yield handler(origin, content)
+ await handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
@@ -816,7 +787,7 @@ class FederationHandlerRegistry(object):
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
- logger.warn("No handler registered for query type %s", query_type)
+ logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
@@ -840,7 +811,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__()
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
@@ -848,17 +819,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
- return super(ReplicationFederationHandlerRegistry, self).on_edu(
+ return await super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content
)
- return self._send_edu(edu_type=edu_type, origin=origin, content=content)
+ return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
- def on_query(self, query_type, args):
+ async def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
- return handler(args)
+ return await handler(args)
- return self._get_query_client(query_type=query_type, args=args)
+ return await self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 454456a52..ced4925a9 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -36,6 +36,8 @@ from six import iteritems
from sortedcontainers import SortedDict
+from twisted.internet import defer
+
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
@@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object):
receipt (synapse.types.ReadReceipt):
"""
# nothing to do here: the replication listener will handle it.
- pass
+ return defer.succeed(None)
def send_presence(self, states):
"""As per FederationSender
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc75c3947..a5b36b182 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -192,15 +192,16 @@ class PerDestinationQueue(object):
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
- device_update_edus, dev_list_id = (
- yield self._get_device_update_edus(limit)
+ device_update_edus, dev_list_id = yield self._get_device_update_edus(
+ limit
)
limit -= len(device_update_edus)
- to_device_edus, device_stream_id = (
- yield self._get_to_device_message_edus(limit)
- )
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = yield self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus
@@ -359,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
- now_stream_id, results = yield self._store.get_devices_by_remote(
+ now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
- assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 5b6c79c51..67b3e1ab6 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -146,7 +146,7 @@ class TransactionManager(object):
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
- logger.warn(
+ logger.warning(
"TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
@@ -155,7 +155,7 @@ class TransactionManager(object):
)
else:
for p in pdus:
- logger.warn(
+ logger.warning(
"TX [%s] {%s} Failed to send event %s",
destination,
txn_id,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 7b1840814..920fa8685 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -122,10 +122,10 @@ class TransportLayerClient(object):
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
+ "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
destination,
room_id,
- repr(event_tuples),
+ event_tuples,
str(limit),
)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 0f16f21c2..d6c23f22b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes):
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
- logger.warn(
+ logger.warning(
"Error parsing auth header '%s': %s",
header_bytes.decode("ascii", "replace"),
e,
@@ -287,10 +287,12 @@ class BaseFederationServlet(object):
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.warn("authenticate_request failed: missing authentication")
+ logger.warning(
+ "authenticate_request failed: missing authentication"
+ )
raise
except Exception as e:
- logger.warn("authenticate_request failed: %s", e)
+ logger.warning("authenticate_request failed: %s", e)
raise
request_tags = {
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index dfd7ae041..d950a8b24 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -181,7 +181,7 @@ class GroupAttestionRenewer(object):
elif not self.is_mine_id(user_id):
destination = get_domain_from_id(user_id)
else:
- logger.warn(
+ logger.warning(
"Incorrectly trying to do attestations for user: %r in %r",
user_id,
group_id,
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 8f10b6adb..29e8ffc29 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -488,7 +488,7 @@ class GroupsServerHandler(object):
profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
except Exception as e:
- logger.warn("Error getting profile for %s: %s", user_id, e)
+ logger.warning("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 38bc67191..2d7e6df6e 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -38,9 +38,10 @@ class AccountDataEventSource(object):
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
)
- account_data, room_account_data = (
- yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
- )
+ (
+ account_data,
+ room_account_data,
+ ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 1a87b5883..6407d56f8 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
@defer.inlineCallbacks
def get_whois(self, user):
connections = []
@@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after
- events = yield filter_events_for_client(self.store, user_id, events)
+ events = yield filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events)
@@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = yield self.store.get_state_for_event(event_id)
+ state = yield self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
return writer.finished()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3e9b29815..fe62f78e6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
try:
limit = 100
while True:
- upper_bound, events = yield self.store.get_new_events_for_appservice(
+ (
+ upper_bound,
+ events,
+ ) = yield self.store.get_new_events_for_appservice(
self.current_max, limit
)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 333eb3062..7a0f54ca2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -525,7 +525,7 @@ class AuthHandler(BaseHandler):
result = None
if not user_infos:
- logger.warn("Attempted to login as %s but they do not exist", user_id)
+ logger.warning("Attempted to login as %s but they do not exist", user_id)
elif len(user_infos) == 1:
# a single match (possibly not exact)
result = user_infos.popitem()
@@ -534,7 +534,7 @@ class AuthHandler(BaseHandler):
result = (user_id, user_infos[user_id])
else:
# multiple matches, none of them exact
- logger.warn(
+ logger.warning(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id,
@@ -728,7 +728,7 @@ class AuthHandler(BaseHandler):
result = yield self.validate_hash(password, password_hash)
if not result:
- logger.warn("Failed password login for user %s", user_id)
+ logger.warning("Failed password login for user %s", user_id)
return None
return user_id
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5f23ee448..26ef5e150 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
@trace
@@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
- prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -458,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -656,7 +668,7 @@ class DeviceListUpdater(object):
except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
# TODO: Remember that we are now out of sync and try again
# later
- logger.warn("Failed to handle device list update for %s", user_id)
+ logger.warning("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -694,7 +706,7 @@ class DeviceListUpdater(object):
# up on storing the total list of devices and only handle the
# delta instead.
if len(devices) > 1000:
- logger.warn(
+ logger.warning(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id,
len(devices),
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0043cbea1..73b9e120f 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -52,7 +52,7 @@ class DeviceMessageHandler(object):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
- logger.warn(
+ logger.warning(
"Dropping device message from %r with spoofed sender %r",
origin,
sender_user_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 526379c6f..c4632f898 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
ignore_backoff=True,
)
except CodeMessageException as e:
- logging.warn("Error retrieving alias")
+ logging.warning("Error retrieving alias")
if e.code == 404:
result = None
else:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60b..f09a0b73c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@@ -119,9 +130,10 @@ class E2eKeysHandler(object):
else:
query_list.append((user_id, None))
- user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(query_list)
- )
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = yield self.store.get_user_devices_from_cache(query_list)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
for device_id, device in iteritems(devices):
@@ -207,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
@@ -251,7 +265,7 @@ class E2eKeysHandler(object):
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master|self_signing|user_signing) -> user_id -> key
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -343,7 +357,16 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- return {"device_keys": res}
+ ret = {"device_keys": res}
+
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
@trace
@defer.inlineCallbacks
@@ -688,17 +711,21 @@ class E2eKeysHandler(object):
try:
# get our self-signing key to verify the signatures
- _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "self_signing"
- )
+ (
+ _,
+ self_signing_key_id,
+ self_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
# that we can check if a signature that was sent is a signature of
# the master key or of a device
- master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "master"
- )
+ (
+ master_key,
+ _,
+ master_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
@@ -838,9 +865,11 @@ class E2eKeysHandler(object):
try:
# get our user-signing key to verify the signatures
- user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
- user_id, "user_signing"
- )
+ (
+ user_signing_key,
+ user_signing_key_id,
+ user_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -859,7 +888,11 @@ class E2eKeysHandler(object):
try:
# get the target user's master key, to make sure it matches
# what was sent
- master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
+ (
+ master_key,
+ master_key_id,
+ _,
+ ) = yield self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -1047,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ """Handles incoming signing key updates from federation and updates the DB"""
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding to pending updates list.
+
+ Args:
+ origin (string): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (string): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(
+ self_signing_key
+ )
+ device_ids.append(verify_key.version)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 5e748687e..45fe13c62 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
+ def __init__(self, hs):
+ super(EventHandler, self).__init__(hs)
+ self.storage = hs.get_storage()
+
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
- self.store, user.to_string(), [event], is_peeking=is_peeking
+ self.storage, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 488058fe6..8cafcfdab 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
make_deferred_yieldable,
@@ -109,6 +110,8 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -180,7 +183,7 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warn(
+ logger.warning(
"[%s %s] Received event failed sanity checks", room_id, event_id
)
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
@@ -301,7 +304,7 @@ class FederationHandler(BaseHandler):
# following.
if sent_to_us_directly:
- logger.warn(
+ logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
room_id,
event_id,
@@ -324,7 +327,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.store.get_state_groups_ids(room_id, seen)
+ ours = yield self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -350,10 +353,11 @@ class FederationHandler(BaseHandler):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state, got_auth_chain = (
- yield self.federation_client.get_state_for_room(
- origin, room_id, p
- )
+ (
+ remote_state,
+ got_auth_chain,
+ ) = yield self.federation_client.get_state_for_room(
+ origin, room_id, p
)
# we want the state *after* p; get_state_for_room returns the
@@ -405,7 +409,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
- logger.warn(
+ logger.warning(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
room_id,
@@ -518,7 +522,9 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+ logger.warning(
+ "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
+ )
return
logger.info(
@@ -545,7 +551,7 @@ class FederationHandler(BaseHandler):
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
- logger.warn(
+ logger.warning(
"[%s %s] Received prev_event %s failed history check.",
room_id,
event_id,
@@ -888,7 +894,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
- self.store,
+ self.storage,
self.server_name,
list(extremities_events.values()),
redact=False,
@@ -1059,7 +1065,7 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster
"""
if len(ev.prev_event_ids()) > 20:
- logger.warn(
+ logger.warning(
"Rejecting event %s which has %i prev_events",
ev.event_id,
len(ev.prev_event_ids()),
@@ -1067,7 +1073,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
- logger.warn(
+ logger.warning(
"Rejecting event %s which has %i auth_events",
ev.event_id,
len(ev.auth_event_ids()),
@@ -1101,7 +1107,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the
- server `target_host`.
+ servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
event that we can fill out and sign. This is then sent to the
@@ -1110,6 +1116,15 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
+
+ Args:
+ target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+
+ room_id (str): The ID of the room to join.
+
+ joinee (str): The User ID of the joining user.
+
+ content (dict): The event content to use for the join event.
"""
logger.debug("Joining %s to %s", joinee, room_id)
@@ -1169,6 +1184,22 @@ class FederationHandler(BaseHandler):
yield self._persist_auth_tree(origin, auth_chain, state, event)
+ # Check whether this room is the result of an upgrade of a room we already know
+ # about. If so, migrate over user information
+ predecessor = yield self.store.get_room_predecessor(room_id)
+ if not predecessor:
+ return
+ old_room_id = predecessor["room_id"]
+ logger.debug(
+ "Found predecessor for %s during remote join: %s", room_id, old_room_id
+ )
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ yield member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, room_id
+ )
+
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
@@ -1203,7 +1234,7 @@ class FederationHandler(BaseHandler):
with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
- logger.warn(
+ logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
)
@@ -1250,7 +1281,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
except AuthError as e:
- logger.warn("Failed to create join %r because %s", event, e)
+ logger.warning("Failed to create join to %s because %s", room_id, e)
raise e
event_allowed = yield self.third_party_event_rules.check_event_allowed(
@@ -1494,7 +1525,7 @@ class FederationHandler(BaseHandler):
room_version, event, context, do_sig_check=False
)
except AuthError as e:
- logger.warn("Failed to create new leave %r because %s", event, e)
+ logger.warning("Failed to create new leave %r because %s", event, e)
raise e
return event
@@ -1549,7 +1580,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups(room_id, [event_id])
+ state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1578,7 +1609,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
+ state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1606,7 +1637,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
- events = yield filter_events_for_server(self.store, origin, events)
+ events = yield filter_events_for_server(self.storage, origin, events)
return events
@@ -1636,7 +1667,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(self.store, origin, [event])
+ events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -1788,7 +1819,7 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
- logger.warn("Rejecting %s because %s", e.event_id, err.msg)
+ logger.warning("Rejecting %s because %s", e.event_id, err.msg)
if e == event:
raise
@@ -1841,12 +1872,7 @@ class FederationHandler(BaseHandler):
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
- try:
- yield self.do_auth(origin, event, context, auth_events=auth_events)
- except AuthError as e:
- logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
-
- context.rejected = RejectedReason.AUTH_ERROR
+ context = yield self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled)
@@ -1902,7 +1928,7 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets = yield self.store.get_state_groups(
+ state_sets = yield self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
@@ -1938,7 +1964,7 @@ class FederationHandler(BaseHandler):
try:
event_auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
- logger.warn("Soft-failing %r because %s", event, e)
+ logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
@defer.inlineCallbacks
@@ -1993,7 +2019,7 @@ class FederationHandler(BaseHandler):
)
missing_events = yield filter_events_for_server(
- self.store, origin, missing_events
+ self.storage, origin, missing_events
)
return missing_events
@@ -2015,12 +2041,12 @@ class FederationHandler(BaseHandler):
Also NB that this function adds entries to it.
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context object
"""
room_version = yield self.store.get_room_version(event.room_id)
try:
- yield self._update_auth_events_and_context_for_auth(
+ context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
@@ -2037,8 +2063,10 @@ class FederationHandler(BaseHandler):
try:
event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ return context
@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
@@ -2062,7 +2090,7 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
- defer.Deferred[None]
+ defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
@@ -2101,7 +2129,7 @@ class FederationHandler(BaseHandler):
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
- return
+ return context
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
@@ -2142,7 +2170,7 @@ class FederationHandler(BaseHandler):
if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
- return
+ return context
# FIXME: Assumes we have and stored all the state for all the
# prev_events
@@ -2151,7 +2179,7 @@ class FederationHandler(BaseHandler):
)
if not different_auth:
- return
+ return context
logger.info(
"auth_events refers to events which are not in our calculated auth "
@@ -2198,10 +2226,12 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state)
- yield self._update_context_for_auth_events(
+ context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key
)
+ return context
+
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
@@ -2210,14 +2240,16 @@ class FederationHandler(BaseHandler):
Args:
event (Event): The event we're handling the context for
- context (synapse.events.snapshot.EventContext): event context
- to be updated
+ context (synapse.events.snapshot.EventContext): initial event context
auth_events (dict[(str, str)->str]): Events to update in the event
context.
event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.
+
+ Returns:
+ Deferred[EventContext]: new event context
"""
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2234,7 +2266,7 @@ class FederationHandler(BaseHandler):
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = yield self.store.store_state_group(
+ state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -2242,7 +2274,7 @@ class FederationHandler(BaseHandler):
current_state_ids=current_state_ids,
)
- yield context.update_state(
+ return EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
@@ -2431,10 +2463,12 @@ class FederationHandler(BaseHandler):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying new third party invite %r because %s", event, e)
+ logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
@@ -2487,7 +2521,7 @@ class FederationHandler(BaseHandler):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying third party invite %r because %s", event, e)
+ logger.warning("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
@@ -2495,6 +2529,7 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@@ -2664,7 +2699,7 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
else:
- max_stream_id = yield self.store.persist_events(
+ max_stream_id = yield self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 46eb9ee88..92fecbfc4 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -392,7 +392,7 @@ class GroupsLocalHandler(object):
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
- logger.warn("No profile for user %s: %s", user_id, e)
+ logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {}
return {"state": "invite", "user_profile": user_profile}
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index ba99ddf76..000fbf090 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler):
changed = False
if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
- logger.warn("Received %d response while unbinding threepid", e.code)
+ logger.warning("Received %d response while unbinding threepid", e.code)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(500, "Failed to contact identity server")
@@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
- logger.warn(
+ logger.warning(
'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for "
@@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler):
if self.hs.config.using_identity_server_from_trusted_list:
# Warn that a deprecated config option is in use
- logger.warn(
+ logger.warning(
'The config option "trust_identity_server_for_password_resets" '
'has been replaced by "account_threepid_delegate". '
"Please consult the sample config at docs/sample_config.yaml for "
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index f991efeee..81dce96f4 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
def snapshot_all_rooms(
self,
@@ -126,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
tags_by_room = yield self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(user_id)
+ account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ user_id
)
public_room_ids = yield self.store.get_public_room_ids()
@@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
- self.store.get_state_for_events, [event.event_id]
+ self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(self.store, user_id, messages)
+ messages = yield filter_events_for_client(
+ self.storage, user_id, messages
+ )
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
- room_state = yield self.store.get_state_for_events([member_event_id])
+ room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token)
@@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0f8cce8ff..d682dc2b7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -59,6 +59,8 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@@ -74,15 +76,16 @@ class MessageHandler(object):
Raises:
SynapseError if something went wrong.
"""
- membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -135,12 +138,12 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events
+ self.storage, user_id, last_events
)
event = last_events[0]
if visible_events:
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
@@ -151,9 +154,10 @@ class MessageHandler(object):
% (user_id, room_id, at_token),
)
else:
- membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(room_id, user_id)
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
@@ -161,7 +165,7 @@ class MessageHandler(object):
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
@@ -234,6 +238,7 @@ class EventCreationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -687,7 +692,7 @@ class EventCreationHandler(object):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
- logger.warn("Denying new event %r because %s", event, err)
+ logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
@@ -868,7 +873,7 @@ class EventCreationHandler(object):
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- (event_stream_id, max_stream_id) = yield self.store.persist_event(
+ event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event, context=context
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5744f4579..97f15a1c3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -69,6 +69,8 @@ class PaginationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
@@ -210,9 +212,10 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
- )
+ (
+ membership,
+ member_event_id,
+ ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -255,7 +258,7 @@ class PaginationHandler(object):
events = event_filter.filter(events)
events = yield filter_events_for_client(
- self.store, user_id, events, is_peeking=(member_event_id is None)
+ self.storage, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
@@ -274,7 +277,7 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@@ -295,10 +298,8 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = (
- yield self._event_serializer.serialize_events(
- state, time_now, as_client_event=as_client_event
- )
+ chunk["state"] = yield self._event_serializer.serialize_events(
+ state, time_now, as_client_event=as_client_event
)
return chunk
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 8690f69d4..22e0a04da 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler):
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e:
- logger.warn(
+ logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e)
)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 3e4d8c93a..e3b528d27 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def received_client_read_marker(self, room_id, user_id, event_id):
+ async def received_client_read_marker(self, room_id, user_id, event_id):
"""Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker.
@@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler):
the read marker has changed.
"""
- with (yield self.read_marker_linearizer.queue((room_id, user_id))):
- existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+ with await self.read_marker_linearizer.queue((room_id, user_id)):
+ existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read"
)
@@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler):
if existing_read_marker:
# Only update if the new marker is ahead in the stream
- should_update = yield self.store.is_event_after(
+ should_update = await self.store.is_event_after(
event_id, existing_read_marker["event_id"]
)
if should_update:
content = {"event_id": event_id}
- max_id = yield self.store.add_account_data_to_room(
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6854c751a..9283c039e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
- def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = []
@@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
)
)
- yield self._handle_new_receipts(receipts)
+ await self._handle_new_receipts(receipts)
- @defer.inlineCallbacks
- def _handle_new_receipts(self, receipts):
+ async def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
for receipt in receipts:
- res = yield self.store.insert_receipt(
+ res = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
+ await maybe_awaitable(
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
+ )
)
return True
- @defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+ async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
data={"ts": int(self.clock.time_msec())},
)
- is_new = yield self._handle_new_receipts([receipt])
+ is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
- yield self.federation.send_read_receipt(receipt)
-
- @defer.inlineCallbacks
- def get_receipts_for_room(self, room_id, to_key):
- """Gets all receipts for a room, upto the given key.
- """
- result = yield self.store.get_linearized_receipts_for_room(
- room_id, to_key=to_key
- )
-
- if not result:
- return []
-
- return result
+ await self.federation.send_read_receipt(receipt)
class ReceiptEventSource(object):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 53410f120..cff6b0d37 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = (
- yield room_member_handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_alias
)
room_id = room_id.to_string()
else:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2816bd8f8..e92b2eafd 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
+
return ret
@defer.inlineCallbacks
@@ -147,21 +148,22 @@ class RoomCreationHandler(BaseHandler):
# we create and auth the tombstone event before properly creating the new
# room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester,
- {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- },
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = yield self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
},
- token_id=requester.access_token_id,
- )
+ },
+ token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
@@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- # and finally, shut down the PLs in the old room, and update them in the new
+ # Copy over user push rules, tags and migrate room directory state
+ yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, new_room_id
+ )
+
+ # finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state
@@ -822,6 +829,8 @@ class RoomContextHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter):
@@ -848,7 +857,7 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
- self.store, user.to_string(), events, is_peeking=is_peeking
+ self.storage, user.to_string(), events, is_peeking=is_peeking
)
event = yield self.store.get_event(
@@ -890,7 +899,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.store.get_state_for_events(
+ state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
results["state"] = list(state[last_event_id].values())
@@ -922,7 +931,7 @@ class RoomEventSource(object):
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
- logger.warn("Stream has topological part!!!! %r", from_key)
+ logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = self.store.get_app_service_by_user_id(user.to_string())
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 380e2fad5..06d09c294 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,10 +203,6 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- # Copy over user state if we're joining an upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -455,11 +451,6 @@ class RoomMemberHandler(object):
requester, remote_room_hosts, room_id, target, content
)
- # Copy over user state if this is a join on an remote upgraded room
- yield self.copy_user_state_if_room_upgrade(
- room_id, requester.user.to_string()
- )
-
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
@@ -498,36 +489,72 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
- def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
- """Copy user-specific information when they join a new room if that new room is the
- result of a room upgrade
+ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ """Upon our server becoming aware of an upgraded room, either by upgrading a room
+ ourselves or joining one, we can transfer over information from the previous room.
+
+ Copies user state (tags/push rules) for every local user that was in the old room, as
+ well as migrating the room directory state.
Args:
- new_room_id (str): The ID of the room the user is joining
- user_id (str): The ID of the user
+ old_room_id (str): The ID of the old room
+
+ room_id (str): The ID of the new room
+
+ Returns:
+ Deferred
+ """
+ # Find all local users that were in the old room and copy over each user's state
+ users = yield self.store.get_users_in_room(old_room_id)
+ yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+ # Add new room to the room directory if the old room was there
+ # Remove old room from the room directory
+ old_room = yield self.store.get_room(old_room_id)
+ if old_room and old_room["is_public"]:
+ yield self.store.set_room_is_public(old_room_id, False)
+ yield self.store.set_room_is_public(room_id, True)
+
+ @defer.inlineCallbacks
+ def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ """Copy user-specific information when they join a new room when that new room is the
+ result of a room upgrade
+
+ Args:
+ old_room_id (str): The ID of upgraded room
+ new_room_id (str): The ID of the new room
+ user_ids (Iterable[str]): User IDs to copy state for
Returns:
Deferred
"""
- # Check if the new room is an upgraded room
- predecessor = yield self.store.get_room_predecessor(new_room_id)
- if not predecessor:
- return
logger.debug(
- "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ "Copying over room tags and push rules from %s to %s for users %s",
+ old_room_id,
new_room_id,
- predecessor,
+ user_ids,
)
- # It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], new_room_id, user_id
- )
- # Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], new_room_id, user_id
- )
+ for user_id in user_ids:
+ try:
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ old_room_id, new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ except Exception:
+ logger.exception(
+ "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+ "Skipping...",
+ old_room_id,
+ new_room_id,
+ user_id,
+ )
+ continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
@@ -759,22 +786,25 @@ class RoomMemberHandler(object):
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
- token, public_keys, fallback_public_key, display_name = (
- yield self.identity_handler.ask_id_server_for_third_party_invite(
- requester=requester,
- id_server=id_server,
- medium=medium,
- address=address,
- room_id=room_id,
- inviter_user_id=user.to_string(),
- room_alias=canonical_room_alias,
- room_avatar_url=room_avatar_url,
- room_join_rules=room_join_rules,
- room_name=room_name,
- inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url,
- id_access_token=id_access_token,
- )
+ (
+ token,
+ public_keys,
+ fallback_public_key,
+ display_name,
+ ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ requester=requester,
+ id_server=id_server,
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ inviter_user_id=user.to_string(),
+ room_alias=canonical_room_alias,
+ room_avatar_url=room_avatar_url,
+ room_join_rules=room_join_rules,
+ room_name=room_name,
+ inviter_display_name=inviter_display_name,
+ inviter_avatar_url=inviter_avatar_url,
+ id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd5e90bac..56ed262a1 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
)
res["events_before"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_before"]
+ self.storage, user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_after"]
+ self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
@@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.store.get_state_for_event(
+ state = yield self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -394,15 +396,11 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = (
- yield self._event_serializer.serialize_events(
- context["events_before"], time_now
- )
+ context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"], time_now
)
- context["events_after"] = (
- yield self._event_serializer.serialize_events(
- context["events_after"], time_now
- )
+ context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"], time_now
)
state_results = {}
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 26bc27669..7f7d56390 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
user_deltas = {}
# Then count deltas for total_events and total_event_bytes.
- room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
+ (
+ room_count,
+ user_count,
+ ) = yield self.store.get_changes_room_total_events_and_bytes(
self.pos, max_pos
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d99160e9d..b536d410e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -230,6 +230,8 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
@@ -417,7 +419,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
- self.store,
+ self.storage,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -470,7 +472,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
- self.store,
+ self.storage,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -509,7 +511,7 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
)
if event.is_state():
@@ -580,7 +582,7 @@ class SyncHandler(object):
return None
last_event = last_events[-1]
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -757,11 +759,11 @@ class SyncHandler(object):
if full_state:
if batch:
- current_state_ids = yield self.store.get_state_ids_for_event(
+ current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
@@ -781,7 +783,7 @@ class SyncHandler(object):
)
elif batch.limited:
if batch:
- state_at_timeline_start = yield self.store.get_state_ids_for_event(
+ state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
else:
@@ -810,7 +812,7 @@ class SyncHandler(object):
)
if batch:
- current_state_ids = yield self.store.get_state_ids_for_event(
+ current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
else:
@@ -841,7 +843,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@@ -1204,10 +1206,11 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- account_data, account_data_by_room = (
- yield self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_updated_account_data_for_user(
+ user_id, since_token.account_data_key
)
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
@@ -1219,9 +1222,10 @@ class SyncHandler(object):
sync_config.user
)
else:
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(sync_config.user.to_string())
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5aa..8363d887a 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
super().__init__(hs)
self._enabled = bool(hs.config.recaptcha_private_key)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key
diff --git a/synapse/http/client.py b/synapse/http/client.py
index cdf828a4f..d4c285445 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
@@ -183,7 +184,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -192,6 +201,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -236,11 +247,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
@@ -535,7 +548,7 @@ class SimpleHttpClient(object):
b"Content-Length" in resp_headers
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
@@ -543,7 +556,7 @@ class SimpleHttpClient(object):
)
if response.code > 299:
- logger.warn("Got %d when downloading %s" % (response.code, url))
+ logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 000000000..be7b2ceb8
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
+ pass
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 3fe4ffb9e..021b233a7 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -148,7 +148,7 @@ class SrvResolver(object):
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
+ logger.warning(
"Failed to resolve %r, falling back to cache. %r", service_name, e
)
return list(cache_entry)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3f7c93ffc..691380abd 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
body = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
@@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object):
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
- logger.warn(
+ logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
@@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object):
break
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object):
raise
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object):
d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 000000000..332da02a8
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to place outgoing
+ connections.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://[:port], which allows a way of requiring a TLS connection to the
+ # proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 46af27c8f..58f9cc61c 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -170,7 +170,7 @@ class RequestMetrics(object):
tag = context.tag
if context != self.start_context:
- logger.warn(
+ logger.warning(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2ccb210fd..943d12c90 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -454,7 +454,7 @@ def respond_with_json(
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
- logger.warn(
+ logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 274c1a6a8..e9a5e46ce 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False):
try:
content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
- logger.warn("Unable to decode UTF-8")
+ logger.warning("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e:
- logger.warn("Unable to parse JSON: %s", e)
+ logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
diff --git a/synapse/http/site.py b/synapse/http/site.py
index df5274c17..ff8184a3d 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -199,7 +199,7 @@ class SynapseRequest(Request):
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
- logger.warn(
+ logger.warning(
"Error processing request %r: %s %s", self, reason.type, reason.value
)
@@ -305,7 +305,7 @@ class SynapseRequest(Request):
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ logger.warning("Failed to stop metrics: %r", e)
class XForwardedForRequest(SynapseRequest):
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 3220e985a..334ddaf39 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
def parse_drain_configs(
- drains: dict
+ drains: dict,
) -> typing.Generator[DrainConfiguration, None, None]:
"""
Parse the drain configurations.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 370000e37..2c1fb9dda 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -294,7 +294,7 @@ class LoggingContext(object):
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
if self.previous_context != old_context:
- logger.warn(
+ logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4e091314e..af161a81d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -159,6 +159,7 @@ class Notifier(object):
self.room_to_user_streams = {}
self.hs = hs
+ self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
@@ -425,7 +426,10 @@ class Notifier(object):
if name == "room":
new_events = yield filter_events_for_client(
- self.store, user.to_string(), new_events, is_peeking=is_peeking
+ self.storage,
+ user.to_string(),
+ new_events,
+ is_peeking=is_peeking,
)
elif name == "presence":
now = self.clock.time_msec()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 22491f370..1ba7bcd4d 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = self._get_rules_for_room(room_id)
+ rules_for_room = yield self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
@@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object):
room_members = yield self.store.get_joined_users_from_context(event, context)
- (power_levels, sender_power_level) = (
- yield self._get_power_levels_and_sender_level(event, context)
- )
+ (
+ power_levels,
+ sender_power_level,
+ ) = yield self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 42e5b0c0a..8c818a86b 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -234,14 +234,12 @@ class EmailPusher(object):
return
self.last_stream_ordering = last_stream_ordering
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.email,
- self.user_id,
- last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 629958780..e994037be 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -64,6 +64,7 @@ class HttpPusher(object):
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"]
@@ -102,7 +103,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
@@ -210,14 +211,12 @@ class HttpPusher(object):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = (
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.last_stream_ordering,
- self.clock.time_msec(),
- )
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_id,
+ self.last_stream_ordering,
+ self.clock.time_msec(),
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
@@ -246,7 +245,7 @@ class HttpPusher(object):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
- logger.warn(
+ logger.warning(
"Giving up on a notification to user %s, " "pushkey %s",
self.user_id,
self.pushkey,
@@ -299,7 +298,7 @@ class HttpPusher(object):
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
- logger.warn(
+ logger.warning(
("Ignoring rejected pushkey %s because we" " didn't send it"),
pk,
)
@@ -329,7 +328,7 @@ class HttpPusher(object):
return d
ctx = yield push_tools.get_context_for_event(
- self.store, self.state_handler, event, self.user_id
+ self.storage, self.state_handler, event, self.user_id
)
d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5b16ab4ae..1d15a06a5 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -119,6 +119,7 @@ class Mailer(object):
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
+ self.storage = hs.get_storage()
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
@@ -389,7 +390,7 @@ class Mailer(object):
}
the_events = yield filter_events_for_client(
- self.store, user_id, results["events_before"]
+ self.storage, user_id, results["events_before"]
)
the_events.append(notif_event)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 5ed9147de..b1587183a 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object):
pattern = UserID.from_string(user_id).localpart
if not pattern:
- logger.warn("event_match condition with no pattern")
+ logger.warning("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
@@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False):
regex_cache[(glob, word_boundary)] = r
return r.search(value)
except re.error:
- logger.warn("Failed to parse glob to regex: %r", glob)
+ logger.warning("Failed to parse glob to regex: %r", glob)
return False
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a54051a72..de5c101a5 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
+from synapse.storage import Storage
@defer.inlineCallbacks
@@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {}
- room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
+ room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = yield calculate_room_name(
- store, room_state_ids, user_id, fallback_to_single_member=False
+ storage.main, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
- sender_state_event = yield store.get_event(sender_state_event_id)
+ sender_state_event = yield storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 08e840fdc..0f6992202 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -103,9 +103,7 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
- last_stream_ordering = (
- yield self.store.get_latest_push_action_stream_ordering()
- )
+ last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
user_id=user_id,
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index aa7da1c54..5871feaaf 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -61,7 +61,6 @@ REQUIREMENTS = [
"bcrypt>=3.1.0",
"pillow>=4.3.0",
"sortedcontainers>=1.4.4",
- "psutil>=2.0.0",
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0..c8056b0c0 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -110,14 +110,14 @@ class ReplicationEndpoint(object):
return {}
@abc.abstractmethod
- def _handle_request(self, request, **kwargs):
+ async def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
- Deferred[dict]: A JSON serialisable dict to be used as response
- body of request.
+ tuple[int, dict]: HTTP status code and a JSON serialisable dict
+ to be used as response body of request.
"""
pass
@@ -180,7 +180,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
- logger.warn("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f1695595..9af4e7e17 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request):
+ async def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
- context = yield EventContext.deserialize(
- self.store, event_payload["context"]
- )
+ context = EventContext.deserialize(self.store, event_payload["context"])
event_and_contexts.append((event, context))
logger.info("Got %d events from federation", len(event_and_contexts))
- yield self.federation_handler.persist_events_and_notify(
+ await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
@@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
- @defer.inlineCallbacks
- def _handle_request(self, request, edu_type):
+ async def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
@@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin)
- result = yield self.registry.on_edu(edu_type, origin, edu_content)
+ result = await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result
@@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""
return {"args": args}
- @defer.inlineCallbacks
- def _handle_request(self, request, query_type):
+ async def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
@@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
logger.info("Got %r query", query_type)
- result = yield self.registry.on_query(query_type, args)
+ result = await self.registry.on_query(query_type, args)
return 200, result
@@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""
return {}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id):
- yield self.store.clean_room_for_join(room_id)
+ async def _handle_request(self, request, room_id):
+ await self.store.clean_room_for_join(room_id)
return 200, {}
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 786f5232b..798b9d3af 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest
)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477a..cc1f24974 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
@@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
- yield self.federation_handler.do_invite_join(
+ await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
@@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"remote_room_hosts": remote_room_hosts,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = yield self.federation_handler.do_remotely_reject_invite(
+ event = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id
)
ret = event.get_pdu_json()
@@ -148,9 +144,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(user_id, room_id)
+ await self.store.locally_reject_invite(user_id, room_id)
ret = {}
return 200, ret
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 38260256c..915cfb943 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"address": address,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- yield self.registration_handler.register_with_store(
+ await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
was_guest=content["was_guest"],
@@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
"""
return {"auth_result": auth_result, "access_token": access_token}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
auth_result = content["auth_result"]
access_token = content["access_token"]
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index adb9b2f7f..9bafd60b1 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request, event_id):
+ async def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
@@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
- context = yield EventContext.deserialize(self.store, content["context"])
+ context = EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- yield self.event_creation_handler.persist_and_notify_client_event(
+ await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d..456bc005a 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict
import six
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665a..de50748c3 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index a44ceb00e..fead78388 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
"""
import logging
+from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
FederationAckCommand,
InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +176,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5ffdf2675..afaf002fe 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -249,7 +250,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc..5f52264e8 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c25..9e45429d4 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_user_signature_changes_for_remotes
+
+ super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 939418ee2..5c2a2eb59 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, stream_ordering
)
if not r:
- logger.warn(
+ logger.warning(
"[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)",
ts,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8414af08c..24a0ce74f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet):
address = address.lower()
# Check for login providers that support 3pid login types
- canonical_user_id, callback_3pid = (
- yield self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = yield self.auth_handler.check_password_provider_3pid(
+ medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
@@ -221,7 +222,7 @@ class LoginRestServlet(RestServlet):
medium, address
)
if not user_id:
- logger.warn(
+ logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet):
def do_token_login(self, login_submission):
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = (
- yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
result = yield self._register_device_with_callback(user_id, login_submission)
@@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet):
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9c1d41421..86bbcc0ee 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
- info = yield self._room_creation_handler.create_room(
+ info = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_type, state_key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_type, state_key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
)
msg_handler = self.message_handler
- data = yield msg_handler.get_room_data(
+ data = await msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
@@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
if txn_id:
set_tag("txn_id", txn_id)
@@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.room_member_handler.update_membership(
+ event = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_type, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, event_type, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict = {
@@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_identifier, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_identifier, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
content = parse_json_object_from_request(request)
@@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
server = parse_string(request, "server", default=None)
try:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
@@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
+ data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
)
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
)
return 200, data
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
@@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
+ data = await handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
@@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
third_party_instance_id=third_party_instance_id,
)
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
@@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
+ async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
- events = yield handler.get_state_events(
+ events = await handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
@@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- users_with_profile = yield self.message_handler.get_joined_members(
+ users_with_profile = await self.message_handler.get_joined_members(
requester, room_id
)
@@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
@@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
as_client_event = False
else:
event_filter = None
- msgs = yield self.pagination_handler.get_messages(
+ msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
- events = yield self.message_handler.get_state_events(
+ events = await self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
@@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
- content = yield self.initial_sync_handler.room_initial_sync(
+ content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
return 200, content
@@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
- event = yield self.event_handler.get_event(
+ event = await self.event_handler.get_event(
requester.user, room_id, event_id
)
except AuthError:
@@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
+ event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = yield self.room_context_handler.get_event_context(
+ results = await self.room_context_handler.get_event_context(
requester.user, room_id, event_id, limit, event_filter
)
@@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = yield self._event_serializer.serialize_events(
+ results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now
)
- results["event"] = yield self._event_serializer.serialize_event(
+ results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now
)
- results["events_after"] = yield self._event_serializer.serialize_events(
+ results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now
)
- results["state"] = yield self._event_serializer.serialize_events(
+ results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
)
@@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ async def on_POST(self, request, room_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
- yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
+ await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
@@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, membership_action, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
@@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.room_member_handler.do_3pid_invite(
+ await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content and membership_action in ["kick", "ban"]:
event_content = {"reason": content["reason"]}
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
- yield self.typing_handler.started_typing(
+ await self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=timeout,
)
else:
- yield self.typing_handler.stopped_typing(
+ await self.typing_handler.stopped_typing(
target_user=target_user, auth_user=requester.user, room_id=room_id
)
@@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = yield self.handlers.search_handler.search(
+ results = await self.handlers.search_handler.search(
requester.user, content, batch
)
@@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
+ room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 80cf7126a..f26eae794 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User password resets have been disabled due to lack of email config"
)
raise SynapseError(
@@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_password_reset_template_failure_html],
)
@@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Password reset emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
+ logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
@@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_add_threepid_template_failure_html],
)
@@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Adding emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index b3bf8567e..67cbc3731 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
if read_event_id:
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
user_id=requester.user.to_string(),
@@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
read_marker_event_id = body.get("m.fully_read", None)
if read_marker_event_id:
- yield self.read_marker_handler.received_client_read_marker(
+ await self.read_marker_handler.received_client_read_marker(
room_id,
user_id=requester.user.to_string(),
event_id=read_marker_event_id,
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 0dab03d22..92555bd4a 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
@@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, receipt_type, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, receipt_type, event_id):
+ requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 4f24a124a..91db92381 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Email registration has been disabled due to lack of email config"
)
raise SynapseError(
@@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
+ logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
@@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self.failure_email_template, = load_jinja2_templates(
+ (self.failure_email_template,) = load_jinja2_templates(
self.config.email_template_dir,
[self.config.email_registration_template_failure_html],
)
@@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User registration via email has been disabled due to lack of email config"
)
raise SynapseError(
@@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet):
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
- logger.warn("Ignoring initial_device_display_name without password")
+ logger.warning("Ignoring initial_device_display_name without password")
del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a883c8add..ccd8b17b2 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -112,9 +112,14 @@ class SyncRestServlet(RestServlet):
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
- "/sync: user=%r, timeout=%r, since=%r,"
- " set_presence=%r, filter_id=%r, device_id=%r"
- % (user, timeout, since, set_presence, filter_id, device_id)
+ "/sync: user=%r, timeout=%r, since=%r, "
+ "set_presence=%r, filter_id=%r, device_id=%r",
+ user,
+ timeout,
+ since,
+ set_presence,
+ filter_id,
+ device_id,
)
request_key = (user, timeout, since, filter_id, full_state, device_id)
@@ -389,7 +394,7 @@ class SyncRestServlet(RestServlet):
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
- logger.warn(
+ logger.warning(
"Event %r is under room %r instead of %r",
event.event_id,
room.room_id,
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 1044ae7b4..bb30ce3f3 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
"m.require_identity_server": False,
# as per MSC2290
"m.separate_add_and_bind": True,
+ # Implements support for label-based filtering as described in
+ # MSC2326.
+ "org.matrix.label_based_filtering": True,
},
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 55580bc59..e7fc3f043 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource):
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
- server, = request.postpath
+ (server,) = request.postpath
query = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index b972e152a..bd9186fe5 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -363,7 +363,7 @@ class MediaRepository(object):
},
)
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
@@ -372,7 +372,7 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
- logger.warn(
+ logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
@@ -383,10 +383,12 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
- logger.warn("Failed to fetch remote media %s/%s", server_name, media_id)
+ logger.warning(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise
except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
+ logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
@@ -691,7 +693,7 @@ class MediaRepository(object):
try:
os.remove(full_path)
except OSError as e:
- logger.warn("Failed to remove file: %r", full_path)
+ logger.warning("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 4d4b3c146..69544b371 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -77,6 +77,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -120,8 +122,10 @@ class PreviewUrlResource(DirectServeResource):
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
- ("Matching attrib '%s' with value '%s' against" " pattern '%s'")
- % (attrib, value, pattern)
+ "Matching attrib '%s' with value '%s' against" " pattern '%s'",
+ attrib,
+ value,
+ pattern,
)
if value is None:
@@ -137,7 +141,7 @@ class PreviewUrlResource(DirectServeResource):
match = False
continue
if match:
- logger.warn("URL %s blocked by url_blacklist entry %s", url, entry)
+ logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
@@ -189,7 +193,7 @@ class PreviewUrlResource(DirectServeResource):
media_info = yield self._download_url(url, user)
- logger.debug("got media_info of '%s'" % media_info)
+ logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
@@ -209,7 +213,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % url)
+ logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media
elif _is_html(media_info["media_type"]):
@@ -257,7 +261,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % og["og:image"])
+ logger.warning("Couldn't get dims for %s", og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name,
@@ -268,7 +272,7 @@ class PreviewUrlResource(DirectServeResource):
else:
del og["og:image"]
else:
- logger.warn("Failed to find any OG data in %s", url)
+ logger.warning("Failed to find any OG data in %s", url)
og = {}
# filter out any stupidly long values
@@ -282,7 +286,7 @@ class PreviewUrlResource(DirectServeResource):
for k in keys_to_remove:
del og[k]
- logger.debug("Calculated OG for %s as %s" % (url, og))
+ logger.debug("Calculated OG for %s as %s", url, og)
jsonog = json.dumps(og)
@@ -311,7 +315,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- logger.debug("Trying to get url '%s'" % url)
+ logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
@@ -331,7 +335,7 @@ class PreviewUrlResource(DirectServeResource):
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
- logger.warn("Error downloading %s: %r", url, e)
+ logger.warning("Error downloading %s: %r", url, e)
raise SynapseError(
500,
@@ -412,7 +416,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -444,7 +448,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
try:
@@ -460,7 +464,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 08329884a..931ce79be 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
@@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index 1fcc7375d..90c3b072e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -95,6 +96,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
+from synapse.storage import DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -167,6 +169,7 @@ class HomeServer(object):
"filtering",
"http_client_context_factory",
"simple_http_client",
+ "proxied_http_client",
"media_repository",
"media_repository_resource",
"federation_transport_client",
@@ -196,6 +199,7 @@ class HomeServer(object):
"account_validity_handler",
"saml_handler",
"event_client_serializer",
+ "storage",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -217,6 +221,7 @@ class HomeServer(object):
self.hostname = hostname
self._building = {}
self._listening_services = []
+ self.start_time = None
self.clock = Clock(reactor)
self.distributor = Distributor()
@@ -224,7 +229,7 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
- self.datastore = None
+ self.datastores = None
# Other kwargs are explicit dependencies
for depname in kwargs:
@@ -233,8 +238,10 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
with self.get_db_conn() as conn:
- self.datastore = self.DATASTORE_CLASS(conn, self)
+ datastore = self.DATASTORE_CLASS(conn, self)
+ self.datastores = DataStores(datastore, conn, self)
conn.commit()
+ self.start_time = int(self.get_clock().time())
logger.info("Finished setting up.")
def setup_master(self):
@@ -266,7 +273,7 @@ class HomeServer(object):
return self.clock
def get_datastore(self):
- return self.datastore
+ return self.datastores.main
def get_config(self):
return self.config
@@ -308,6 +315,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
@@ -537,6 +551,9 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_storage(self) -> Storage:
+ return Storage(self, self.datastores)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 16f8f6b57..b5e0b5709 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,8 +39,16 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
- self
+ self,
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
@@ -47,32 +56,32 @@ class HomeServer(object):
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
def get_event_creation_handler(
- self
+ self,
) -> synapse.handlers.message.EventCreationHandler:
pass
def get_set_password_handler(
- self
+ self,
) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
def get_federation_transport_client(
- self
+ self,
) -> synapse.federation.transport.client.TransportLayerClient:
pass
def get_media_repository_resource(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(
- self
+ self,
) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
def get_server_notices_manager(
- self
+ self,
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
def get_server_notices_sender(
- self
+ self,
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index c0e7f475c..9fae2e0af 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object):
room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id)
if not room_id:
- logger.warn("Failed to get server notices room")
+ logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
new file mode 100644
index 000000000..efcc10f80
--- /dev/null
+++ b/synapse/spam_checker_api/__init__.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+class SpamCheckerApi(object):
+ """A proxy object that gets passed to spam checkers so they can get
+ access to rooms and other relevant information.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+
+ self._store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_state_events_in_room(self, room_id, types):
+ """Gets state events for the given room.
+
+ Args:
+ room_id (string): The room ID to get state events in.
+ types (tuple): The event type and state key (using None
+ to represent 'any') of the room state to acquire.
+
+ Returns:
+ twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
+ The filtered state events in the room.
+ """
+ state_ids = yield self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
+ state = yield self._store.get_events(state_ids.values())
+ return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index dc9f5a900..2c04ab185 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -103,6 +103,7 @@ class StateHandler(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -231,6 +232,9 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
+
+ # FIXME: why do we populate current_state_ids? I thought the point was
+ # that we weren't supposed to have any state for outliers?
if old_state:
prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
@@ -271,7 +275,7 @@ class StateHandler(object):
else:
current_state_ids = prev_state_ids
- state_group = yield self.store.store_state_group(
+ state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
@@ -321,7 +325,7 @@ class StateHandler(object):
delta_ids = dict(entry.delta_ids)
delta_ids[key] = event.event_id
- state_group = yield self.store.store_state_group(
+ state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -334,7 +338,7 @@ class StateHandler(object):
delta_ids = entry.delta_ids
if entry.state_group is None:
- entry.state_group = yield self.store.store_state_group(
+ entry.state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=entry.prev_group,
@@ -376,14 +380,16 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
+ state_groups_ids = yield self.state_store.get_state_groups_ids(
+ room_id, event_ids
+ )
if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+ prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a249ecd21..0a1a8cc1e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,7 +27,26 @@ data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from synapse.storage.data_stores.main import DataStore # noqa: F401
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+ """The high level interfaces for talking to various storage layers.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.persistence = EventsPersistenceStorage(hs, stores)
+ self.state = StateGroupStorage(hs, stores)
def are_all_users_on_domain(txn, database_engine, domain):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f5906fcd5..1a2b7ebe2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -494,7 +494,7 @@ class SQLBaseStore(object):
exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn("Starting db txn '%s' from sentinel context", desc)
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
result = yield self.runWithConnection(
@@ -532,7 +532,7 @@ class SQLBaseStore(object):
"""
parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel:
- logger.warn(
+ logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
@@ -719,7 +719,7 @@ class SQLBaseStore(object):
raise
# presumably we raced with another transaction: let's retry.
- logger.warn(
+ logger.warning(
"IntegrityError when upserting into %s; retrying: %s", table, e
)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 80b57a948..37d469ffd 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -94,13 +94,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = False
def start_doing_background_updates(self):
- run_as_background_process("background_updates", self._run_background_updates)
+ run_as_background_process("background_updates", self.run_background_updates)
@defer.inlineCallbacks
- def _run_background_updates(self):
+ def run_background_updates(self, sleep=True):
logger.info("Starting background schema updates")
while True:
- yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+ if sleep:
+ yield self.hs.get_clock().sleep(
+ self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0
+ )
try:
result = yield self.do_next_background_update(
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 56094078e..cb184a98c 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -12,3 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+
+class DataStores(object):
+ """The various data stores.
+
+ These are low level interfaces to physical databases.
+ """
+
+ def __init__(self, main_store, db_conn, hs):
+ # Note we pass in the main store here as workers use a different main
+ # store.
+ self.main = main_store
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index b185ba0b3..10c940df1 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -139,7 +139,10 @@ class DataStore(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
@@ -317,7 +320,7 @@ class DataStore(
) u
"""
txn.execute(sql, (time_from,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
def count_r30_users(self):
@@ -396,7 +399,7 @@ class DataStore(
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
results["all"] = count
return results
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a354234..71f62036c 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
- update_context = update[3]
- update_stream_id = update[2]
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- previous_update_stream_id, _ = query_map.get(key, (0, None))
-
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
return now_stream_id, results
- def _get_devices_by_remote_txn(
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index ef88e7929..1cbbae5b6 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _delete_e2e_room_keys_version_txn(txn):
if version is None:
this_version = self._get_current_version(txn, user_id)
+ if this_version is None:
+ raise StoreError(404, "No current backup version")
else:
this_version = version
+ self._simple_delete_txn(
+ txn,
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": this_version},
+ )
+
return self._simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index a0bc6f2d1..073412a78 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self._execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index a470a48e0..90bef0cd2 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -364,9 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
- logger.debug(
- "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
- )
+ logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
event_results = set()
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 22025effb..04ce21ac6 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
- offset_stream_ordering, = stream_row
+ (offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 03b5111c5..301f8ea12 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,28 +17,26 @@
import itertools
import logging
-from collections import Counter as c_counter, OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, namedtuple
from functools import wraps
from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event_dict
-from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.event_federation import EventFederationStore
@@ -46,10 +44,8 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -60,37 +56,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
- "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
- "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
-# The number of forward extremities for each new event.
-forward_extremities_counter = Histogram(
- "synapse_storage_events_forward_extremities_persisted",
- "Number of forward extremities for each new event",
- buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
-# The number of stale forward extremities for each new event. Stale extremities
-# are those that were in the previous set of extremities as well as the new.
-stale_forward_extremities_counter = Histogram(
- "synapse_storage_events_stale_forward_extremities_persisted",
- "Number of unchanged forward extremities for each new event",
- buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
-)
-
def encode_json(json_object):
"""
@@ -102,110 +67,6 @@ def encode_json(json_object):
return out
-class _EventPeristenceQueue(object):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
- """
-
- _EventPersistQueueItem = namedtuple(
- "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
- )
-
- def __init__(self):
- self._event_persist_queues = {}
- self._currently_persisting_rooms = set()
-
- def add_to_queue(self, room_id, events_and_contexts, backfilled):
- """Add events to the queue, with the given persist_event options.
-
- NB: due to the normal usage pattern of this method, it does *not*
- follow the synapse logcontext rules, and leaves the logcontext in
- place whether or not the returned deferred is ready.
-
- Args:
- room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
-
- Returns:
- defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext.
- """
- queue = self._event_persist_queues.setdefault(room_id, deque())
- if queue:
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- end_item = queue[-1]
- if end_item.backfilled == backfilled:
- end_item.events_and_contexts.extend(events_and_contexts)
- return end_item.deferred.observe()
-
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
- queue.append(
- self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
- deferred=deferred,
- )
- )
-
- return deferred.observe()
-
- def handle_queue(self, room_id, per_item_callback):
- """Attempts to handle the queue for a room if not already being handled.
-
- The given callback will be invoked with for each item in the queue,
- of type _EventPersistQueueItem. The per_item_callback will continuously
- be called with new items, unless the queue becomnes empty. The return
- value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferreds as well.
-
- This function should therefore be called whenever anything is added
- to the queue.
-
- If another callback is currently handling the queue then it will not be
- invoked.
- """
-
- if room_id in self._currently_persisting_rooms:
- return
-
- self._currently_persisting_rooms.add(room_id)
-
- @defer.inlineCallbacks
- def handle_queue_loop():
- try:
- queue = self._get_drainining_queue(room_id)
- for item in queue:
- try:
- ret = yield per_item_callback(item)
- except Exception:
- with PreserveLoggingContext():
- item.deferred.errback()
- else:
- with PreserveLoggingContext():
- item.deferred.callback(ret)
- finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
- self._currently_persisting_rooms.discard(room_id)
-
- # set handle_queue_loop off in the background
- run_as_background_process("persist_events", handle_queue_loop)
-
- def _get_drainining_queue(self, room_id):
- queue = self._event_persist_queues.setdefault(room_id, deque())
-
- try:
- while True:
- yield queue.popleft()
- except IndexError:
- # Queue has been drained.
- pass
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -221,7 +82,7 @@ def _retry_on_integrity_error(func):
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
- res = yield func(self, *args, **kwargs)
+ res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
@@ -241,9 +102,6 @@ class EventsStore(
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self._event_persist_queue = _EventPeristenceQueue()
- self._state_resolution_handler = hs.get_state_resolution_handler()
-
# Collect metrics on the number of forward extremities that exist.
# Counter of number of extremities to count
self._current_forward_extremities_amount = c_counter()
@@ -286,340 +144,106 @@ class EventsStore(
res = yield self.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
- @defer.inlineCallbacks
- def persist_events(self, events_and_contexts, backfilled=False):
- """
- Write events to the database
- Args:
- events_and_contexts: list of tuples of (event, context)
- backfilled (bool): Whether the results are retrieved from federation
- via backfill or not. Used to determine if they're "new" events
- which might update the current state etc.
-
- Returns:
- Deferred[int]: the stream ordering of the latest persisted event
- """
- partitioned = {}
- for event, ctx in events_and_contexts:
- partitioned.setdefault(event.room_id, []).append((event, ctx))
-
- deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
- d = self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
- )
- deferreds.append(d)
-
- for room_id in partitioned:
- self._maybe_start_persisting(room_id)
-
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
-
- return max_persisted_id
-
- @defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False):
- """
-
- Args:
- event (EventBase):
- context (EventContext):
- backfilled (bool):
-
- Returns:
- Deferred: resolves to (int, int): the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
- """
- deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
- )
-
- self._maybe_start_persisting(event.room_id)
-
- yield make_deferred_yieldable(deferred)
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
-
- def _maybe_start_persisting(self, room_id):
- @defer.inlineCallbacks
- def persisting_queue(item):
- with Measure(self._clock, "persist_events"):
- yield self._persist_events(
- item.events_and_contexts, backfilled=item.backfilled
- )
-
- self._event_persist_queue.handle_queue(room_id, persisting_queue)
-
@_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(
- self, events_and_contexts, backfilled=False, delete_existing=False
+ def _persist_events_and_state_updates(
+ self,
+ events_and_contexts,
+ current_state_for_room,
+ state_delta_for_room,
+ new_forward_extremeties,
+ backfilled=False,
+ delete_existing=False,
):
- """Persist events to db
+ """Persist a set of events alongside updates to the current state and
+ forward extremities tables.
Args:
events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
+ current_state_for_room (dict[str, dict]): Map from room_id to the
+ current state of the room based on forward extremities
+ state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
+ of `(to_delete, to_insert)` where to_delete is a list
+ of type/state keys to remove from current state, and to_insert
+ is a map (type,key)->event_id giving the state delta in each
+ room.
+ new_forward_extremities (dict[str, list[str]]): Map from room_id
+ to list of event IDs that are the new forward extremities of
+ the room.
+ backfilled (bool)
delete_existing (bool):
Returns:
Deferred: resolves when the events have been persisted
"""
- if not events_and_contexts:
- return
- chunks = [
- events_and_contexts[x : x + 100]
- for x in range(0, len(events_and_contexts), 100)
- ]
+ # We want to calculate the stream orderings as late as possible, as
+ # we only notify after all events with a lesser stream ordering have
+ # been persisted. I.e. if we spend 10s inside the with block then
+ # that will delay all subsequent events from being notified about.
+ # Hence why we do it down here rather than wrapping the entire
+ # function.
+ #
+ # Its safe to do this after calculating the state deltas etc as we
+ # only need to protect the *persistence* of the events. This is to
+ # ensure that queries of the form "fetch events since X" don't
+ # return events and stream positions after events that are still in
+ # flight, as otherwise subsequent requests "fetch event since Y"
+ # will not return those events.
+ #
+ # Note: Multiple instances of this function cannot be in flight at
+ # the same time for the same room.
+ if backfilled:
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ len(events_and_contexts)
+ )
+ else:
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ len(events_and_contexts)
+ )
- for chunk in chunks:
- # We can't easily parallelize these since different chunks
- # might contain the same event. :(
+ with stream_ordering_manager as stream_orderings:
+ for (event, context), stream in zip(events_and_contexts, stream_orderings):
+ event.internal_metadata.stream_ordering = stream
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->list[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremeties = {}
-
- # map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events.
- # This is simply used to prefill the get_current_state_ids
- # cache
- current_state_for_room = {}
-
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room = {}
+ yield self.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ )
+ persist_event_counter.inc(len(events_and_contexts))
if not backfilled:
- with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
- # We do this by working out what the new extremities are and then
- # calculating the state from that.
- events_by_room = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in iteritems(events_by_room):
- latest_event_ids = yield self.get_latest_event_ids_in_room(
- room_id
- )
- new_latest_event_ids = yield self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- latest_event_ids = set(latest_event_ids)
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremeties[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.info("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = yield self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids = res
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- state_delta_for_room[room_id] = ([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = yield self._calculate_state_delta(
- room_id, current_state
- )
- state_delta_for_room[room_id] = delta
-
- # If we have the current_state then lets prefill
- # the cache with it.
- if current_state is not None:
- current_state_for_room[room_id] = current_state
-
- # We want to calculate the stream orderings as late as possible, as
- # we only notify after all events with a lesser stream ordering have
- # been persisted. I.e. if we spend 10s inside the with block then
- # that will delay all subsequent events from being notified about.
- # Hence why we do it down here rather than wrapping the entire
- # function.
- #
- # Its safe to do this after calculating the state deltas etc as we
- # only need to protect the *persistence* of the events. This is to
- # ensure that queries of the form "fetch events since X" don't
- # return events and stream positions after events that are still in
- # flight, as otherwise subsequent requests "fetch event since Y"
- # will not return those events.
- #
- # Note: Multiple instances of this function cannot be in flight at
- # the same time for the same room.
- if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
- len(chunk)
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ events_and_contexts[-1][0].internal_metadata.stream_ordering
)
- else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
- with stream_ordering_manager as stream_orderings:
- for (event, context), stream in zip(chunk, stream_orderings):
- event.internal_metadata.stream_ordering = stream
+ for event, context in events_and_contexts:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
- yield self.runInteraction(
- "persist_events",
- self._persist_events_txn,
- events_and_contexts=chunk,
- backfilled=backfilled,
- delete_existing=delete_existing,
- state_delta_for_room=state_delta_for_room,
- new_forward_extremeties=new_forward_extremeties,
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
+
+ for room_id, new_state in iteritems(current_state_for_room):
+ self.get_current_state_ids.prefill((room_id,), new_state)
+
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
)
- persist_event_counter.inc(len(chunk))
-
- if not backfilled:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering
- )
-
- for event, context in chunk:
- if context.app_service:
- origin_type = "local"
- origin_entity = context.app_service.id
- elif self.hs.is_mine_id(event.sender):
- origin_type = "local"
- origin_entity = "*client*"
- else:
- origin_type = "remote"
- origin_entity = get_domain_from_id(event.sender)
-
- event_counter.labels(event.type, origin_type, origin_entity).inc()
-
- for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
-
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
- (room_id,), list(latest_event_ids)
- )
-
- @defer.inlineCallbacks
- def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremities for a room given events to
- persist.
-
- Assumes that we are only persisting events for one room at a time.
- """
-
- # we're only interested in new events which aren't outliers and which aren't
- # being rejected.
- new_events = [
- event
- for event, ctx in event_contexts
- if not event.internal_metadata.is_outlier()
- and not ctx.rejected
- and not event.internal_metadata.is_soft_failed()
- ]
-
- latest_event_ids = set(latest_event_ids)
-
- # start with the existing forward extremities
- result = set(latest_event_ids)
-
- # add all the new events to the list
- result.update(event.event_id for event in new_events)
-
- # Now remove all events which are prev_events of any of the new events
- result.difference_update(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
-
- # Remove any events which are prev_events of any existing events.
- existing_prevs = yield self._get_events_which_are_prevs(result)
- result.difference_update(existing_prevs)
-
- # Finally handle the case where the new events have soft-failed prev
- # events. If they do we need to remove them and their prev events,
- # otherwise we end up with dangling extremities.
- existing_prevs = yield self._get_prevs_before_rejected(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
- result.difference_update(existing_prevs)
-
- # We only update metrics for events that change forward extremities
- # (e.g. we ignore backfill/outliers/etc)
- if result != latest_event_ids:
- forward_extremities_counter.observe(len(result))
- stale = latest_event_ids & result
- stale_forward_extremities_counter.observe(len(stale))
-
- return result
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -725,188 +349,6 @@ class EventsStore(
return existing_prevs
- @defer.inlineCallbacks
- def _get_new_state_after_events(
- self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
- ):
- """Calculate the current state dict after adding some new events to
- a room
-
- Args:
- room_id (str):
- room to which the events are being added. Used for logging etc
-
- events_context (list[(EventBase, EventContext)]):
- events and contexts which are being added to the room
-
- old_latest_event_ids (iterable[str]):
- the old forward extremities for the room.
-
- new_latest_event_ids (iterable[str]):
- the new forward extremities for the room.
-
- Returns:
- Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
-
- If there has been a change then we only return the delta if its
- already been calculated. Conversely if we do know the delta then
- the new current state is only returned if we've already calculated
- it.
- """
- # map from state_group to ((type, key) -> event_id) state map
- state_groups_map = {}
-
- # Map from (prev state group, new state group) -> delta state dict
- state_group_deltas = {}
-
- for ev, ctx in events_context:
- if ctx.state_group is None:
- # This should only happen for outlier events.
- if not ev.internal_metadata.is_outlier():
- raise Exception(
- "Context for new event %s has no state "
- "group" % (ev.event_id,)
- )
- continue
-
- if ctx.state_group in state_groups_map:
- continue
-
- # We're only interested in pulling out state that has already
- # been cached in the context. We'll pull stuff out of the DB later
- # if necessary.
- current_state_ids = ctx.get_cached_current_state_ids()
- if current_state_ids is not None:
- state_groups_map[ctx.state_group] = current_state_ids
-
- if ctx.prev_group:
- state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
- # We need to map the event_ids to their state groups. First, let's
- # check if the event is one we're persisting, in which case we can
- # pull the state group from its context.
- # Otherwise we need to pull the state group from the database.
-
- # Set of events we need to fetch groups for. (We know none of the old
- # extremities are going to be in events_context).
- missing_event_ids = set(old_latest_event_ids)
-
- event_id_to_state_group = {}
- for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding.
- for ev, ctx in events_context:
- if event_id == ev.event_id and ctx.state_group is not None:
- event_id_to_state_group[event_id] = ctx.state_group
- break
- else:
- # If we couldn't find it, then we'll need to pull
- # the state from the database
- missing_event_ids.add(event_id)
-
- if missing_event_ids:
- # Now pull out the state groups for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
- event_id_to_state_group.update(event_to_groups)
-
- # State groups of old_latest_event_ids
- old_state_groups = set(
- event_id_to_state_group[evid] for evid in old_latest_event_ids
- )
-
- # State groups of new_latest_event_ids
- new_state_groups = set(
- event_id_to_state_group[evid] for evid in new_latest_event_ids
- )
-
- # If they old and new groups are the same then we don't need to do
- # anything.
- if old_state_groups == new_state_groups:
- return None, None
-
- if len(new_state_groups) == 1 and len(old_state_groups) == 1:
- # If we're going from one state group to another, lets check if
- # we have a delta for that transition. If we do then we can just
- # return that.
-
- new_state_group = next(iter(new_state_groups))
- old_state_group = next(iter(old_state_groups))
-
- delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
- if delta_ids is not None:
- # We have a delta from the existing to new current state,
- # so lets just return that. If we happen to already have
- # the current state in memory then lets also return that,
- # but it doesn't matter if we don't.
- new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
-
- # Now that we have calculated new_state_groups we need to get
- # their state IDs so we can resolve to a single state set.
- missing_state = new_state_groups - set(state_groups_map)
- if missing_state:
- group_to_state = yield self._get_state_for_groups(missing_state)
- state_groups_map.update(group_to_state)
-
- if len(new_state_groups) == 1:
- # If there is only one state group, then we know what the current
- # state is.
- return state_groups_map[new_state_groups.pop()], None
-
- # Ok, we need to defer to the state handler to resolve our state sets.
-
- state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- # We need to get the room version, which is in the create event.
- # Normally that'd be in the database, but its also possible that we're
- # currently trying to persist it.
- room_version = None
- for ev, _ in events_context:
- if ev.type == EventTypes.Create and ev.state_key == "":
- room_version = ev.content.get("room_version", "1")
- break
-
- if not room_version:
- room_version = yield self.get_room_version(room_id)
-
- logger.debug("calling resolve_state_groups from preserve_events")
- res = yield self._state_resolution_handler.resolve_state_groups(
- room_id,
- room_version,
- state_groups,
- events_map,
- state_res_store=StateResolutionStore(self),
- )
-
- return res.state, None
-
- @defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, current_state):
- """Calculate the new state deltas for a room.
-
- Assumes that we are only persisting events for one room at a time.
-
- Returns:
- tuple[list, dict] (to_delete, to_insert): where to_delete are the
- type/state_keys to remove from current_state_events and `to_insert`
- are the updates to current_state_events.
- """
- existing_state = yield self.get_current_state_ids(room_id)
-
- to_delete = [key for key in existing_state if key not in current_state]
-
- to_insert = {
- key: ev_id
- for key, ev_id in iteritems(current_state)
- if ev_id != existing_state.get(key)
- }
-
- return to_delete, to_insert
-
@log_function
def _persist_events_txn(
self,
@@ -1490,6 +932,13 @@ class EventsStore(
self._handle_event_relations(txn, event)
+ # Store the labels for this event.
+ labels = event.content.get(EventContentFields.LABELS)
+ if labels:
+ self.insert_labels_for_event_txn(
+ txn, event.event_id, labels, event.room_id, event.depth
+ )
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1683,7 +1132,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_messages", _count_messages)
@@ -1704,7 +1153,7 @@ class EventsStore(
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
@@ -1719,7 +1168,7 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
@@ -2204,7 +1653,7 @@ class EventsStore(
""",
(room_id,),
)
- min_depth, = txn.fetchone()
+ (min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -2396,7 +1845,6 @@ class EventsStore(
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
- "topics",
"users_in_public_rooms",
"users_who_share_private_rooms",
# no useful index, but let's clear them anyway
@@ -2439,12 +1887,11 @@ class EventsStore(
logger.info("[purge] done")
- @defer.inlineCallbacks
- def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
- to_1, so_1 = yield self._get_event_ordering(event_id1)
- to_2, so_2 = yield self._get_event_ordering(event_id2)
+ to_1, so_1 = await self._get_event_ordering(event_id1)
+ to_2, so_2 = await self._get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
@@ -2477,6 +1924,33 @@ class EventsStore(
get_all_updated_current_state_deltas_txn,
)
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
+ """
+ return self._simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 31ea6f917..51352b996 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not rows:
return 0
- upper_event_id, = rows[-1]
+ (upper_event_id,) = rows[-1]
# Update the redactions with the received_ts.
#
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index aeae5a2b2..b3a2771f1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index e6ee1e4aa..b41c3d317 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index cd95f1ce6..b520062d8 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -143,7 +143,7 @@ class PushRulesWorkerStore(
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return bool(count)
return self.runInteraction(
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f005c1ae0..d76861cdc 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore):
r["data"] = json.loads(dataJson)
except Exception as e:
- logger.warn(
+ logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
dataJson,
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 6c5b29288..f70d41eca 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE appservice_id IS NULL
"""
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index e47ab604d..2af24a20b 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- cache = self._get_joined_hosts_cache(room_id)
+ cache = yield self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
return joined_hosts
@@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
if not row or not row[0]:
return processed, True
- next_room, = row
+ (next_room,) = row
sql = """
UPDATE current_state_events
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
new file mode 100644
index 000000000..1d2ddb1b1
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* delete room keys that belong to deleted room key version, or to room key
+ * versions that don't exist (anymore)
+ */
+DELETE FROM e2e_room_keys
+WHERE version NOT IN (
+ SELECT version
+ FROM e2e_room_keys_versions
+ WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id
+ AND e2e_room_keys_versions.deleted = 0
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 000000000..5e29c1da1
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topoligical_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+ event_id TEXT,
+ label TEXT,
+ room_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ PRIMARY KEY(event_id, label)
+);
+
+
+-- This index enables an event pagination looking for a particular label to index the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 000000000..e8b1fd35d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ last_seen BIGINT,
+ ip TEXT,
+ user_agent TEXT,
+ hidden BOOLEAN DEFAULT 0,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 0e0849745..d1d7c6863 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
" ON event_search USING GIN (vector)"
)
except psycopg2.ProgrammingError as e:
- logger.warn(
+ logger.warning(
"Ignoring error %r when trying to switch from GIST to GIN", e
)
@@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
)
)
txn.execute(query, (value, search_query))
- headline, = txn.fetchall()[0]
+ (headline,) = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
# result.
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d54442e5f..313284803 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Iterable, Tuple
from six import iteritems, itervalues
from six.moves import range
@@ -23,6 +24,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
@@ -722,16 +725,18 @@ class StateGroupWorkerStore(
member_filter, non_member_filter = state_filter.get_member_split()
# Now we look them up in the member and non-member caches
- non_member_state, incomplete_groups_nm, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
)
- member_state, incomplete_groups_m, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
)
state = dict(non_member_state)
@@ -1073,7 +1078,7 @@ class StateBackgroundUpdateStore(
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- prev_group, = txn.fetchone()
+ (prev_group,) = txn.fetchone()
new_last_state_group = state_group
if prev_group:
@@ -1215,7 +1220,9 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
- def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 5ab639b2a..45b3de7d5 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -332,7 +332,7 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
- logger.info(
+ logger.debug(
"Updating %s stats for %s: %s", stats_type, stats_id, fields
)
self._update_stats_delta_txn(
@@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
(room_id,),
)
- current_state_events_count, = txn.fetchone()
+ (current_state_events_count,) = txn.fetchone()
users_in_room = self.get_users_in_room_txn(txn, room_id)
@@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
""",
(user_id,),
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count, pos
joined_rooms, pos = yield self.runInteraction(
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 263999dfc..616ef91d4 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
+ # We're only applying the "labels" filter on the database query, because applying the
+ # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+ # event_filter.check_fields apply it, which is not as efficient but makes the
+ # implementation simpler.
+ if event_filter.labels:
+ clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+ args.extend(event_filter.labels)
+
return " AND ".join(clauses), args
@@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.append(int(limit))
sql = (
- "SELECT event_id, topological_ordering, stream_ordering"
+ "SELECT DISTINCT event_id, topological_ordering, stream_ordering"
" FROM events"
+ " LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 000000000..fa03ca9ff
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,649 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import deque, namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+class _EventPeristenceQueue(object):
+ """Queues up events so that they can be persisted in bulk with only one
+ concurrent transaction per room.
+ """
+
+ _EventPersistQueueItem = namedtuple(
+ "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+ )
+
+ def __init__(self):
+ self._event_persist_queues = {}
+ self._currently_persisting_rooms = set()
+
+ def add_to_queue(self, room_id, events_and_contexts, backfilled):
+ """Add events to the queue, with the given persist_event options.
+
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
+ """
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+ if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ end_item = queue[-1]
+ if end_item.backfilled == backfilled:
+ end_item.events_and_contexts.extend(events_and_contexts)
+ return end_item.deferred.observe()
+
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+ queue.append(
+ self._EventPersistQueueItem(
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ )
+
+ return deferred.observe()
+
+ def handle_queue(self, room_id, per_item_callback):
+ """Attempts to handle the queue for a room if not already being handled.
+
+ The given callback will be invoked with for each item in the queue,
+ of type _EventPersistQueueItem. The per_item_callback will continuously
+ be called with new items, unless the queue becomnes empty. The return
+ value of the function will be given to the deferreds waiting on the item,
+ exceptions will be passed to the deferreds as well.
+
+ This function should therefore be called whenever anything is added
+ to the queue.
+
+ If another callback is currently handling the queue then it will not be
+ invoked.
+ """
+
+ if room_id in self._currently_persisting_rooms:
+ return
+
+ self._currently_persisting_rooms.add(room_id)
+
+ @defer.inlineCallbacks
+ def handle_queue_loop():
+ try:
+ queue = self._get_drainining_queue(room_id)
+ for item in queue:
+ try:
+ ret = yield per_item_callback(item)
+ except Exception:
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ finally:
+ queue = self._event_persist_queues.pop(room_id, None)
+ if queue:
+ self._event_persist_queues[room_id] = queue
+ self._currently_persisting_rooms.discard(room_id)
+
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
+
+ def _get_drainining_queue(self, room_id):
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ try:
+ while True:
+ yield queue.popleft()
+ except IndexError:
+ # Queue has been drained.
+ pass
+
+
+class EventsPersistenceStorage(object):
+ """High level interface for handling persisting newly received events.
+
+ Takes care of batching up events by room, and calculating the necessary
+ current state and forward extremity changes.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We ultimately want to split out the state store from the main store,
+ # so we use separate variables here even though they point to the same
+ # store for now.
+ self.main_store = stores.main
+ self.state_store = stores.main
+
+ self._clock = hs.get_clock()
+ self.is_mine_id = hs.is_mine_id
+ self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ @defer.inlineCallbacks
+ def persist_events(self, events_and_contexts, backfilled=False):
+ """
+ Write events to the database
+ Args:
+ events_and_contexts: list of tuples of (event, context)
+ backfilled (bool): Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ Deferred[int]: the stream ordering of the latest persisted event
+ """
+ partitioned = {}
+ for event, ctx in events_and_contexts:
+ partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+ deferreds = []
+ for room_id, evs_ctxs in iteritems(partitioned):
+ d = self._event_persist_queue.add_to_queue(
+ room_id, evs_ctxs, backfilled=backfilled
+ )
+ deferreds.append(d)
+
+ for room_id in partitioned:
+ self._maybe_start_persisting(room_id)
+
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+
+ return max_persisted_id
+
+ @defer.inlineCallbacks
+ def persist_event(self, event, context, backfilled=False):
+ """
+
+ Args:
+ event (EventBase):
+ context (EventContext):
+ backfilled (bool):
+
+ Returns:
+ Deferred: resolves to (int, int): the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
+ deferred = self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+
+ self._maybe_start_persisting(event.room_id)
+
+ yield make_deferred_yieldable(deferred)
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+ def _maybe_start_persisting(self, room_id):
+ @defer.inlineCallbacks
+ def persisting_queue(item):
+ with Measure(self._clock, "persist_events"):
+ yield self._persist_events(
+ item.events_and_contexts, backfilled=item.backfilled
+ )
+
+ self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+ @defer.inlineCallbacks
+ def _persist_events(self, events_and_contexts, backfilled=False):
+ """Calculates the change to current state and forward extremities, and
+ persists the given events and with those updates.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+ delete_existing (bool):
+
+ Returns:
+ Deferred: resolves when the events have been persisted
+ """
+ if not events_and_contexts:
+ return
+
+ chunks = [
+ events_and_contexts[x : x + 100]
+ for x in range(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+
+ # NB: Assumes that we are only persisting events for one room
+ # at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
+ new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
+ current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
+ if not backfilled:
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
+
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
+ latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+ room_id
+ )
+ new_latest_event_ids = yield self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ continue
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
+
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.info("Calculating state delta for room %s", room_id)
+ with Measure(
+ self._clock, "persist_events.get_new_state_after_events"
+ ):
+ res = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ state_delta_for_room[room_id] = ([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock, "persist_events.calculate_state_delta"
+ ):
+ delta = yield self._calculate_state_delta(
+ room_id, current_state
+ )
+ state_delta_for_room[room_id] = delta
+
+ # If we have the current_state then lets prefill
+ # the cache with it.
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+
+ yield self.main_store._persist_events_and_state_updates(
+ chunk,
+ current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ backfilled=backfilled,
+ )
+
+ @defer.inlineCallbacks
+ def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+ """Calculates the new forward extremities for a room given events to
+ persist.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
+ ]
+
+ latest_event_ids = set(latest_event_ids)
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(event.event_id for event in new_events)
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+
+ # Remove any events which are prev_events of any existing events.
+ existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
+
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = yield self.main_store._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
+ return result
+
+ @defer.inlineCallbacks
+ def _get_new_state_after_events(
+ self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+ ):
+ """Calculate the current state dict after adding some new events to
+ a room
+
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
+
+ Returns:
+ Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+ If there has been a change then we only return the delta if its
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
+ """
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id,)
+ )
+ continue
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding.
+ for ev, ctx in events_context:
+ if event_id == ev.event_id and ctx.state_group is not None:
+ event_id_to_state_group[event_id] = ctx.state_group
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ missing_event_ids.add(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state groups for any missing events from DB
+ event_to_groups = yield self.main_store._get_state_group_for_events(
+ missing_event_ids
+ )
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
+
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
+
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return None, None
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+ # If we're going from one state group to another, lets check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+ # so lets just return that. If we happen to already have
+ # the current state in memory then lets also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ return new_state, delta_ids
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_groups_map[new_state_groups.pop()], None
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ # We need to get the room version, which is in the create event.
+ # Normally that'd be in the database, but its also possible that we're
+ # currently trying to persist it.
+ room_version = None
+ for ev, _ in events_context:
+ if ev.type == EventTypes.Create and ev.state_key == "":
+ room_version = ev.content.get("room_version", "1")
+ break
+
+ if not room_version:
+ room_version = yield self.main_store.get_room_version(room_id)
+
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = yield self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ return res.state, None
+
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ tuple[list, dict] (to_delete, to_insert): where to_delete are the
+ type/state_keys to remove from current_state_events and `to_insert`
+ are the updates to current_state_events.
+ """
+ existing_state = yield self.main_store.get_current_state_ids(room_id)
+
+ to_delete = [key for key in existing_state if key not in current_state]
+
+ to_insert = {
+ key: ev_id
+ for key, ev_id in iteritems(current_state)
+ if ev_id != existing_state.get(key)
+ }
+
+ return to_delete, to_insert
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a2df8fa82..373584689 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from six import iteritems, itervalues
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__)
@@ -322,3 +324,234 @@ class StateFilter(object):
)
return member_filter, non_member_filter
+
+
+class StateGroupStorage(object):
+ """High level interface to fetching state for event.
+ """
+
+ def __init__(self, hs, stores):
+ self.stores = stores
+
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+ (prev_group, delta_ids)
+ """
+
+ return self.stores.main.get_state_group_delta(state_group)
+
+ @defer.inlineCallbacks
+ def get_state_groups_ids(self, _room_id, event_ids):
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id (str): id of the room for these events
+ event_ids (iterable[str]): ids of the events
+
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ if not event_ids:
+ return {}
+
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(groups)
+
+ return group_to_state
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_group(self, state_group):
+ """Get the event IDs of all the state in the given state group
+
+ Args:
+ state_group (int)
+
+ Returns:
+ Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = yield self._get_state_for_groups((state_group,))
+
+ return group_to_state[state_group]
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, room_id, event_ids):
+ """ Get the state groups for the given list of event_ids
+ Returns:
+ Deferred[dict[int, list[EventBase]]]:
+ dict of state_group_id -> list of state events.
+ """
+ if not event_ids:
+ return {}
+
+ group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = yield self.stores.main.get_events(
+ [
+ ev_id
+ for group_ids in itervalues(group_to_ids)
+ for ev_id in itervalues(group_ids)
+ ],
+ get_prev_content=False,
+ )
+
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
+
+ def _get_state_groups_from_groups(self, groups, state_filter):
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups(list[int]): list of state group IDs to query
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+
+ return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+
+ @defer.inlineCallbacks
+ def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+ """Given a list of event_ids and type tuples, return a list of state
+ dicts for each event.
+ Args:
+ event_ids (list[string])
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+ """
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(
+ groups, state_filter
+ )
+
+ state_event_map = yield self.stores.main.get_events(
+ [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+ get_prev_content=False,
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in iteritems(group_to_state[group])
+ if v in state_event_map
+ }
+ for event_id, group in iteritems(event_to_groups)
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+ """
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
+
+ Args:
+ event_ids(list(str)): events whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from event_id -> (type, state_key) -> event_id
+ """
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+ groups = set(itervalues(event_to_groups))
+ group_to_state = yield self.stores.main._get_state_for_groups(
+ groups, state_filter
+ )
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in iteritems(event_to_groups)
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ @defer.inlineCallbacks
+ def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_for_events([event_id], state_filter)
+ return state_map[event_id]
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+ return state_map[event_id]
+
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ return self.stores.main._get_state_for_groups(groups, state_filter)
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ return self.stores.main.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
+ )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index cbb0a4810..9d851beaa 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- val, = cur.fetchone()
+ (val,) = cur.fetchone()
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 804dbca44..5c4de2e69 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -86,11 +86,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ effect the underdlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -105,7 +106,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -138,7 +139,7 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
+ func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
@@ -148,11 +149,10 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
@@ -309,7 +309,7 @@ class Linearizer(object):
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name,
key,
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 43fd65d69..da5077b47 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None):
if collect_callback:
collect_callback()
except Exception as e:
- logger.warn("Error calculating metrics for %s: %s", cache_name, e)
+ logger.warning("Error calculating metrics for %s: %s", cache_name, e)
raise
yield GaugeMetricFamily("__unused", "")
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5ac2530a6..0e8da27f5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
@@ -482,9 +482,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
def __init__(
@@ -618,7 +617,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b1bcdf23..328680432 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -119,7 +119,7 @@ class Measure(object):
context = LoggingContext.current_context()
if context != self.start_context:
- logger.warn(
+ logger.warning(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
self.start_context,
context,
@@ -128,7 +128,7 @@ class Measure(object):
return
if not context:
- logger.warn("Expected context. (%r)", self.name)
+ logger.warning("Expected context. (%r)", self.name)
return
current = context.get_resource_usage()
@@ -140,7 +140,7 @@ class Measure(object):
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
- logger.warn(
+ logger.warning(
"Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
)
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index 6c0f2bb0c..207cd17c2 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no):
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
- logger.warn("Failed to set file or core limit: %s", e)
+ logger.warning("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index fa404b9d7..ab7d03af3 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -42,6 +42,7 @@ def get_version_string(module):
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
+
try:
git_branch = (
subprocess.check_output(
@@ -51,7 +52,8 @@ def get_version_string(module):
.decode("ascii")
)
git_branch = "b=" + git_branch
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # FileNotFoundError can arise when git is not installed
git_branch = ""
try:
@@ -63,7 +65,7 @@ def get_version_string(module):
.decode("ascii")
)
git_tag = "t=" + git_tag
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_tag = ""
try:
@@ -74,7 +76,7 @@ def get_version_string(module):
.strip()
.decode("ascii")
)
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_commit = ""
try:
@@ -89,7 +91,7 @@ def get_version_string(module):
)
git_dirty = "dirty" if is_dirty else ""
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
diff --git a/synapse/visibility.py b/synapse/visibility.py
index bf0f1eebd..8c843febd 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.storage import Storage
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
@@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
def filter_events_for_client(
- store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+ storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
):
"""
Check which events a user is allowed to see
Args:
- store (synapse.storage.DataStore): our datastore (can also be a worker
- store)
+ storage
user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if:
@@ -68,12 +68,12 @@ def filter_events_for_client(
events = list(e for e in events if not e.internal_metadata.is_soft_failed())
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
- event_id_to_state = yield store.get_state_for_events(
+ event_id_to_state = yield storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+ ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id
)
@@ -84,7 +84,7 @@ def filter_events_for_client(
else []
)
- erased_senders = yield store.are_users_erased((e.sender for e in events))
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
def allowed(event):
"""
@@ -213,13 +213,17 @@ def filter_events_for_client(
@defer.inlineCallbacks
def filter_events_for_server(
- store, server_name, events, redact=True, check_history_visibility_only=False
+ storage: Storage,
+ server_name,
+ events,
+ redact=True,
+ check_history_visibility_only=False,
):
"""Filter a list of events based on whether given server is allowed to
see them.
Args:
- store (DataStore)
+ storage
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or
@@ -274,7 +278,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -292,14 +296,14 @@ def filter_events_for_server(
if not visibility_ids:
all_open = True
else:
- event_map = yield store.get_events(visibility_ids)
+ event_map = yield storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map)
)
if not check_history_visibility_only:
- erased_senders = yield store.are_users_erased((e.sender for e in events))
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
@@ -328,7 +332,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -358,7 +362,7 @@ def filter_events_for_server(
return False
return state_key[idx + 1 :] == server_name
- event_map = yield store.get_events(
+ event_map = yield storage.main.get_events(
[
e_id
for e_id, key in iteritems(event_id_to_state_key)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de1..2dc505224 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -19,6 +19,7 @@ import jsonschema
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
@@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
+ "org.matrix.labels": ["#fun"],
+ "org.matrix.not_labels": ["#work"],
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
@@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
)
self.assertFalse(Filter(definition).check(event))
+ def test_filter_labels(self):
+ definition = {"org.matrix.labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ def test_filter_not_labels(self):
+ definition = {"org.matrix.not_labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3d..8efd39c7f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 67f101305..5ec568f4e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
- "get_devices_by_remote",
+ "get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ None
+ )
def test_started_typing_local(self):
self.room_members = [U_APPLE, U_BANANA]
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba646..2096ba3c9 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d702526..cfcd98ff7 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 000000000..22abf7651
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62d..af2327fb6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cdb..4f924ce45 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
+ self.storage = hs.get_storage()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b4..b68e9fe08 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.get_success(
+ self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ )
self.replicate()
event_source = RoomEventSource(self.hs)
@@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
+ self.storage.persistence.persist_events(
+ [(event, context)], backfilled=True
+ )
)
else:
- self.get_success(self.master_store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index d3a4f717f..8e1ca8b73 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ ):
+ count = self.get_success(
+ self.store._simple_select_one_onecol(
+ table="events",
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f2ca7461..5e38fd6ce 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+ def test_filter_labels(self):
+ """Test that we can filter by a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ message_filter = json.dumps(
+ {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+ )
+
+ events = self._test_filter_labels(message_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ )
+
+ events = self._test_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_filter_labels(self, message_filter):
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ )
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
+ % (self.room_id, token, message_filter),
+ )
+ self.render(request)
+
+ return channel.json_body["chunk"]
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b7..8ea0cb05e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -106,13 +106,22 @@ class RestHelper(object):
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
- path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = {"msgtype": "m.text", "body": body}
+
+ return self.send_event(
+ room_id, "m.room.message", content, txn_id, tok, expect_code
+ )
+
+ def send_event(
+ self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ ):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094b..3283c0e47 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from mock import Mock
import synapse.rest.admin
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
@@ -26,7 +28,12 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
- servlets = [sync.register_servlets]
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def make_homeserver(self, reactor, clock):
@@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_sync_filter_labels(self):
+ """Test that we can filter by a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_sync_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_sync_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_sync_filter_labels(self, sync_filter):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ request, channel = self.make_request(
+ "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
+
class SyncTypingTests(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/server.py b/tests/server.py
index e397ebe8f..f878aeaad 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -161,7 +161,11 @@ def make_request(
path = path.encode("ascii")
# Decorate it to be the full path, if we're using shorthand
- if shorthand and not path.startswith(b"/_matrix"):
+ if (
+ shorthand
+ and not path.startswith(b"/_matrix")
+ and not path.startswith(b"/_synapse")
+ ):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -391,11 +395,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -426,6 +443,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -470,6 +490,10 @@ class FakeTransport(object):
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
+
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
"""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a1452..9b81b536f 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d))
- self.assertEquals(a.func("foo"), d.result)
+ self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1..6f8d99095 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
- def test_get_devices_by_remote(self):
+ def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
- def test_get_devices_by_remote_limited(self):
+ def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
#
# first we should get a single update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
new file mode 100644
index 000000000..d128fde44
--- /dev/null
+++ b/tests/storage/test_e2e_room_keys.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+# sample room_key data for use in the tests
+room_key = {
+ "first_message_index": 1,
+ "forwarded_count": 1,
+ "is_verified": False,
+ "session_data": "SSBBTSBBIEZJU0gK",
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_room_keys_version_delete(self):
+ # test that deleting a room key backup deletes the keys
+ version1 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.set_e2e_room_key(
+ "user_id", version1, "room", "session", room_key
+ )
+ )
+
+ version2 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.set_e2e_room_key(
+ "user_id", version2, "room", "session", room_key
+ )
+ )
+
+ # make sure the keys were stored properly
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ # delete version1
+ self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1))
+
+ # make sure the key from version1 is gone, and the key from version2 is
+ # still there
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 0)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 427d3c49c..4561c3e38 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.store.persist_event(event_1, context_1))
+ self.get_success(self.storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.store.persist_event(event_2, context_2))
+ self.get_success(self.storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706..3ddaa151f 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.store.persist_event(
+ yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ff..9ddd17f73 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2d..43200654f 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_datastore = self.store
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -63,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
builder
)
- yield self.store.persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@@ -82,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups_ids(
+ state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1)
@@ -101,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+ state_group_map = yield self.storage.state.get_state_groups(
+ self.room, [e2.event_id]
+ )
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -141,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.store.get_state_for_event(e5.event_id)
+ state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4)
@@ -157,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -181,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+ group_ids = yield self.storage.state.get_state_groups_ids(
+ room_id, [e5.event_id]
+ )
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -267,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -287,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
@@ -346,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
state_dict_ids.pop((e2.type, e2.state_key))
- self.store._state_group_cache.invalidate(group)
- self.store._state_group_cache.update(
- sequence=self.store._state_group_cache.sequence,
+ self.state_datastore._state_group_cache.invalidate(group)
+ self.state_datastore._state_group_cache.update(
+ sequence=self.state_datastore._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),),
)
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -369,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -405,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -424,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index a73f18f88..7d82b5846 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -58,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
)
self.handler = self.homeserver.get_handlers().federation_handler
- self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+ context
+ )
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
@@ -75,7 +78,8 @@ class MessageAcceptTests(unittest.TestCase):
self.assertEqual(
self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0],
"$join:test.serv",
@@ -97,7 +101,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -137,6 +142,6 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
new file mode 100644
index 000000000..7657bddea
--- /dev/null
+++ b/tests/test_phone_home.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+
+import mock
+
+from synapse.app.homeserver import phone_stats_home
+
+from tests.unittest import HomeserverTestCase
+
+
+class PhoneHomeStatsTestCase(HomeserverTestCase):
+ def test_performance_frozen_clock(self):
+ """
+ If time doesn't move, don't error out.
+ """
+ past_stats = [
+ (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ ]
+ stats = {}
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertEqual(stats["cpu_average"], 0)
+
+ def test_performance_100(self):
+ """
+ 1 second of usage over 1 second is 100% CPU usage.
+ """
+ real_res = resource.getrusage(resource.RUSAGE_SELF)
+ old_resource = mock.Mock(spec=real_res)
+ old_resource.ru_utime = real_res.ru_utime - 1
+ old_resource.ru_stime = real_res.ru_stime
+ old_resource.ru_maxrss = real_res.ru_maxrss
+
+ past_stats = [(self.hs.get_clock().time(), old_resource)]
+ stats = {}
+ self.reactor.advance(1)
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb4..38246555b 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -158,10 +158,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
+ storage = Mock(main=self.store, state=self.store)
hs = Mock(
spec_set=[
"config",
"get_datastore",
+ "get_storage",
"get_auth",
"get_state_handler",
"get_clock",
@@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+ hs.get_storage.return_value = storage
self.state = StateHandler(hs)
self.event_id = 0
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035..f7381b288 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+from mock import Mock
+
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens.
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering")
start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 5713870f4..39e360fe2 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
@@ -325,9 +325,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
diff --git a/tests/utils.py b/tests/utils.py
index 8cced4b7e..7dc9bdc50 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -325,10 +325,16 @@ def setup_test_homeserver(
if homeserverToUse.__name__ == "TestHomeServer":
hs.setup_master()
else:
+ # If we have been given an explicit datastore we probably want to mock
+ # out the DataStores somehow too. This all feels a bit wrong, but then
+ # mocking the stores feels wrong too.
+ datastores = Mock(datastore=datastore)
+
hs = homeserverToUse(
name,
db_pool=None,
datastore=datastore,
+ datastores=datastores,
config=config,
version_string="Synapse/tests",
database_engine=db_engine,
@@ -646,7 +652,7 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
- store = hs.get_datastore()
+ persistence_store = hs.get_storage().persistence
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
@@ -663,4 +669,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
diff --git a/tox.ini b/tox.ini
index e3a53f340..afe9bc909 100644
--- a/tox.ini
+++ b/tox.ini
@@ -114,16 +114,16 @@ skip_install = True
basepython = python3.6
deps =
flake8
- black==19.3b0 # We pin so that our tests don't start failing on new releases of black.
+ black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
commands =
python -m black --check --diff .
- /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
+ /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
skip_install = True
deps = isort
-commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests"
+commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts"
[testenv:check-newsfragment]
skip_install = True
@@ -167,6 +167,6 @@ deps =
env =
MYPYPATH = stubs/
extras = all
-commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
+commands = mypy \
synapse/logging/ \
synapse/config/