diff --git a/CHANGES.md b/CHANGES.md index 490c2021e..0ffdf1aae 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,169 @@ +Synapse 0.99.5.2 (2019-05-30) +============================= + +Bugfixes +-------- + +- Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. ([\#5274](https://github.com/matrix-org/synapse/issues/5274), [\#5278](https://github.com/matrix-org/synapse/issues/5278), [\#5291](https://github.com/matrix-org/synapse/issues/5291)) + + +Synapse 0.99.5.1 (2019-05-22) +============================= + +0.99.5.1 supersedes 0.99.5 due to malformed debian changelog - no functional changes. + +Synapse 0.99.5 (2019-05-22) +=========================== + +No significant changes. + + +Synapse 0.99.5rc1 (2019-05-21) +============================== + +Features +-------- + +- Add ability to blacklist IP ranges for the federation client. ([\#5043](https://github.com/matrix-org/synapse/issues/5043)) +- Ratelimiting configuration for clients sending messages and the federation server has been altered to match login ratelimiting. The old configuration names will continue working. Check the sample config for details of the new names. ([\#5181](https://github.com/matrix-org/synapse/issues/5181)) +- Drop support for the undocumented /_matrix/client/v2_alpha API prefix. ([\#5190](https://github.com/matrix-org/synapse/issues/5190)) +- Add an option to disable per-room profiles. ([\#5196](https://github.com/matrix-org/synapse/issues/5196)) +- Stick an expiration date to any registered user missing one at startup if account validity is enabled. ([\#5204](https://github.com/matrix-org/synapse/issues/5204)) +- Add experimental support for relations (aka reactions and edits). ([\#5209](https://github.com/matrix-org/synapse/issues/5209), [\#5211](https://github.com/matrix-org/synapse/issues/5211), [\#5203](https://github.com/matrix-org/synapse/issues/5203), [\#5212](https://github.com/matrix-org/synapse/issues/5212)) +- Add a room version 4 which uses a new event ID format, as per [MSC2002](https://github.com/matrix-org/matrix-doc/pull/2002). ([\#5210](https://github.com/matrix-org/synapse/issues/5210), [\#5217](https://github.com/matrix-org/synapse/issues/5217)) + + +Bugfixes +-------- + +- Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill. ([\#5039](https://github.com/matrix-org/synapse/issues/5039)) +- Exclude soft-failed events from forward-extremity candidates: fixes "No forward extremities left!" error. ([\#5146](https://github.com/matrix-org/synapse/issues/5146)) +- Re-order stages in registration flows such that msisdn and email verification are done last. ([\#5174](https://github.com/matrix-org/synapse/issues/5174)) +- Fix 3pid guest invites. ([\#5177](https://github.com/matrix-org/synapse/issues/5177)) +- Fix a bug where the register endpoint would fail with M_THREEPID_IN_USE instead of returning an account previously registered in the same session. ([\#5187](https://github.com/matrix-org/synapse/issues/5187)) +- Prevent registration for user ids that are too long to fit into a state key. Contributed by Reid Anderson. ([\#5198](https://github.com/matrix-org/synapse/issues/5198)) +- Fix incompatibility between ACME support and Python 3.5.2. ([\#5218](https://github.com/matrix-org/synapse/issues/5218)) +- Fix error handling for rooms whose versions are unknown. ([\#5219](https://github.com/matrix-org/synapse/issues/5219)) + + +Internal Changes +---------------- + +- Make /sync attempt to return device updates for both joined and invited users. Note that this doesn't currently work correctly due to other bugs. ([\#3484](https://github.com/matrix-org/synapse/issues/3484)) +- Update tests to consistently be configured via the same code that is used when loading from configuration files. ([\#5171](https://github.com/matrix-org/synapse/issues/5171), [\#5185](https://github.com/matrix-org/synapse/issues/5185)) +- Allow client event serialization to be async. ([\#5183](https://github.com/matrix-org/synapse/issues/5183)) +- Expose DataStore._get_events as get_events_as_list. ([\#5184](https://github.com/matrix-org/synapse/issues/5184)) +- Make generating SQL bounds for pagination generic. ([\#5191](https://github.com/matrix-org/synapse/issues/5191)) +- Stop telling people to install the optional dependencies by default. ([\#5197](https://github.com/matrix-org/synapse/issues/5197)) + + +Synapse 0.99.4 (2019-05-15) +=========================== + +No significant changes. + + +Synapse 0.99.4rc1 (2019-05-13) +============================== + +Features +-------- + +- Add systemd-python to the optional dependencies to enable logging to the systemd journal. Install with `pip install matrix-synapse[systemd]`. ([\#4339](https://github.com/matrix-org/synapse/issues/4339)) +- Add a default .m.rule.tombstone push rule. ([\#4867](https://github.com/matrix-org/synapse/issues/4867)) +- Add ability for password provider modules to bind email addresses to users upon registration. ([\#4947](https://github.com/matrix-org/synapse/issues/4947)) +- Implementation of [MSC1711](https://github.com/matrix-org/matrix-doc/pull/1711) including config options for requiring valid TLS certificates for federation traffic, the ability to disable TLS validation for specific domains, and the ability to specify your own list of CA certificates. ([\#4967](https://github.com/matrix-org/synapse/issues/4967)) +- Remove presence list support as per MSC 1819. ([\#4989](https://github.com/matrix-org/synapse/issues/4989)) +- Reduce CPU usage starting pushers during start up. ([\#4991](https://github.com/matrix-org/synapse/issues/4991)) +- Add a delete group admin API. ([\#5002](https://github.com/matrix-org/synapse/issues/5002)) +- Add config option to block users from looking up 3PIDs. ([\#5010](https://github.com/matrix-org/synapse/issues/5010)) +- Add context to phonehome stats. ([\#5020](https://github.com/matrix-org/synapse/issues/5020)) +- Configure the example systemd units to have a log identifier of `matrix-synapse` + instead of the executable name, `python`. + Contributed by Christoph Müller. ([\#5023](https://github.com/matrix-org/synapse/issues/5023)) +- Add time-based account expiration. ([\#5027](https://github.com/matrix-org/synapse/issues/5027), [\#5047](https://github.com/matrix-org/synapse/issues/5047), [\#5073](https://github.com/matrix-org/synapse/issues/5073), [\#5116](https://github.com/matrix-org/synapse/issues/5116)) +- Add support for handling `/versions`, `/voip` and `/push_rules` client endpoints to client_reader worker. ([\#5063](https://github.com/matrix-org/synapse/issues/5063), [\#5065](https://github.com/matrix-org/synapse/issues/5065), [\#5070](https://github.com/matrix-org/synapse/issues/5070)) +- Add a configuration option to require authentication on /publicRooms and /profile endpoints. ([\#5083](https://github.com/matrix-org/synapse/issues/5083)) +- Move admin APIs to `/_synapse/admin/v1`. (The old paths are retained for backwards-compatibility, for now). ([\#5119](https://github.com/matrix-org/synapse/issues/5119)) +- Implement an admin API for sending server notices. Many thanks to @krombel who provided a foundation for this work. ([\#5121](https://github.com/matrix-org/synapse/issues/5121), [\#5142](https://github.com/matrix-org/synapse/issues/5142)) + + +Bugfixes +-------- + +- Avoid redundant URL encoding of redirect URL for SSO login in the fallback login page. Fixes a regression introduced in [#4220](https://github.com/matrix-org/synapse/pull/4220). Contributed by Marcel Fabian Krüger ("[zaugin](https://github.com/zauguin)"). ([\#4555](https://github.com/matrix-org/synapse/issues/4555)) +- Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server. ([\#4942](https://github.com/matrix-org/synapse/issues/4942), [\#5103](https://github.com/matrix-org/synapse/issues/5103)) +- Fix sync bug which made accepting invites unreliable in worker-mode synapses. ([\#4955](https://github.com/matrix-org/synapse/issues/4955), [\#4956](https://github.com/matrix-org/synapse/issues/4956)) +- start.sh: Fix the --no-rate-limit option for messages and make it bypass rate limit on registration and login too. ([\#4981](https://github.com/matrix-org/synapse/issues/4981)) +- Transfer related groups on room upgrade. ([\#4990](https://github.com/matrix-org/synapse/issues/4990)) +- Prevent the ability to kick users from a room they aren't in. ([\#4999](https://github.com/matrix-org/synapse/issues/4999)) +- Fix issue #4596 so synapse_port_db script works with --curses option on Python 3. Contributed by Anders Jensen-Waud . ([\#5003](https://github.com/matrix-org/synapse/issues/5003)) +- Clients timing out/disappearing while downloading from the media repository will now no longer log a spurious "Producer was not unregistered" message. ([\#5009](https://github.com/matrix-org/synapse/issues/5009)) +- Fix "cannot import name execute_batch" error with postgres. ([\#5032](https://github.com/matrix-org/synapse/issues/5032)) +- Fix disappearing exceptions in manhole. ([\#5035](https://github.com/matrix-org/synapse/issues/5035)) +- Workaround bug in twisted where attempting too many concurrent DNS requests could cause it to hang due to running out of file descriptors. ([\#5037](https://github.com/matrix-org/synapse/issues/5037)) +- Make sure we're not registering the same 3pid twice on registration. ([\#5071](https://github.com/matrix-org/synapse/issues/5071)) +- Don't crash on lack of expiry templates. ([\#5077](https://github.com/matrix-org/synapse/issues/5077)) +- Fix the ratelimiting on third party invites. ([\#5104](https://github.com/matrix-org/synapse/issues/5104)) +- Add some missing limitations to room alias creation. ([\#5124](https://github.com/matrix-org/synapse/issues/5124), [\#5128](https://github.com/matrix-org/synapse/issues/5128)) +- Limit the number of EDUs in transactions to 100 as expected by synapse. Thanks to @superboum for this work! ([\#5138](https://github.com/matrix-org/synapse/issues/5138)) + +Internal Changes +---------------- + +- Add test to verify threepid auth check added in #4435. ([\#4474](https://github.com/matrix-org/synapse/issues/4474)) +- Fix/improve some docstrings in the replication code. ([\#4949](https://github.com/matrix-org/synapse/issues/4949)) +- Split synapse.replication.tcp.streams into smaller files. ([\#4953](https://github.com/matrix-org/synapse/issues/4953)) +- Refactor replication row generation/parsing. ([\#4954](https://github.com/matrix-org/synapse/issues/4954)) +- Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`. ([\#4959](https://github.com/matrix-org/synapse/issues/4959)) +- Remove log line for password via the admin API. ([\#4965](https://github.com/matrix-org/synapse/issues/4965)) +- Fix typo in TLS filenames in docker/README.md. Also add the '-p' commandline option to the 'docker run' example. Contributed by Jurrie Overgoor. ([\#4968](https://github.com/matrix-org/synapse/issues/4968)) +- Refactor room version definitions. ([\#4969](https://github.com/matrix-org/synapse/issues/4969)) +- Reduce log level of .well-known/matrix/client responses. ([\#4972](https://github.com/matrix-org/synapse/issues/4972)) +- Add `config.signing_key_path` that can be read by `synapse.config` utility. ([\#4974](https://github.com/matrix-org/synapse/issues/4974)) +- Track which identity server is used when binding a threepid and use that for unbinding, as per MSC1915. ([\#4982](https://github.com/matrix-org/synapse/issues/4982)) +- Rewrite KeyringTestCase as a HomeserverTestCase. ([\#4985](https://github.com/matrix-org/synapse/issues/4985)) +- README updates: Corrected the default POSTGRES_USER. Added port forwarding hint in TLS section. ([\#4987](https://github.com/matrix-org/synapse/issues/4987)) +- Remove a number of unused tables from the database schema. ([\#4992](https://github.com/matrix-org/synapse/issues/4992), [\#5028](https://github.com/matrix-org/synapse/issues/5028), [\#5033](https://github.com/matrix-org/synapse/issues/5033)) +- Run `black` on the remainder of `synapse/storage/`. ([\#4996](https://github.com/matrix-org/synapse/issues/4996)) +- Fix grammar in get_current_users_in_room and give it a docstring. ([\#4998](https://github.com/matrix-org/synapse/issues/4998)) +- Clean up some code in the server-key Keyring. ([\#5001](https://github.com/matrix-org/synapse/issues/5001)) +- Convert SYNAPSE_NO_TLS Docker variable to boolean for user friendliness. Contributed by Gabriel Eckerson. ([\#5005](https://github.com/matrix-org/synapse/issues/5005)) +- Refactor synapse.storage._base._simple_select_list_paginate. ([\#5007](https://github.com/matrix-org/synapse/issues/5007)) +- Store the notary server name correctly in server_keys_json. ([\#5024](https://github.com/matrix-org/synapse/issues/5024)) +- Rewrite Datastore.get_server_verify_keys to reduce the number of database transactions. ([\#5030](https://github.com/matrix-org/synapse/issues/5030)) +- Remove extraneous period from copyright headers. ([\#5046](https://github.com/matrix-org/synapse/issues/5046)) +- Update documentation for where to get Synapse packages. ([\#5067](https://github.com/matrix-org/synapse/issues/5067)) +- Add workarounds for pep-517 install errors. ([\#5098](https://github.com/matrix-org/synapse/issues/5098)) +- Improve logging when event-signature checks fail. ([\#5100](https://github.com/matrix-org/synapse/issues/5100)) +- Factor out an "assert_requester_is_admin" function. ([\#5120](https://github.com/matrix-org/synapse/issues/5120)) +- Remove the requirement to authenticate for /admin/server_version. ([\#5122](https://github.com/matrix-org/synapse/issues/5122)) +- Prevent an exception from being raised in a IResolutionReceiver and use a more generic error message for blacklisted URL previews. ([\#5155](https://github.com/matrix-org/synapse/issues/5155)) +- Run `black` on the tests directory. ([\#5170](https://github.com/matrix-org/synapse/issues/5170)) +- Fix CI after new release of isort. ([\#5179](https://github.com/matrix-org/synapse/issues/5179)) +- Fix bogus imports in unit tests. ([\#5154](https://github.com/matrix-org/synapse/issues/5154)) + + +Synapse 0.99.3.2 (2019-05-03) +============================= + +Internal Changes +---------------- + +- Ensure that we have `urllib3` <1.25, to resolve incompatibility with `requests`. ([\#5135](https://github.com/matrix-org/synapse/issues/5135)) + + +Synapse 0.99.3.1 (2019-05-03) +============================= + +Security update +--------------- + +This release includes two security fixes: + +- Switch to using a cryptographically-secure random number generator for token strings, ensuring they cannot be predicted by an attacker. Thanks to @opnsec for identifying and responsibly disclosing this issue! ([\#5133](https://github.com/matrix-org/synapse/issues/5133)) +- Blacklist 0.0.0.0 and :: by default for URL previews. Thanks to @opnsec for identifying and responsibly disclosing this issue too! ([\#5134](https://github.com/matrix-org/synapse/issues/5134)) + Synapse 0.99.3 (2019-04-01) =========================== diff --git a/INSTALL.md b/INSTALL.md index d55a1f89a..193459314 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -35,7 +35,7 @@ virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip pip install --upgrade setuptools -pip install matrix-synapse[all] +pip install matrix-synapse ``` This will download Synapse from [PyPI](https://pypi.org/project/matrix-synapse) @@ -48,7 +48,7 @@ update flag: ``` source ~/synapse/env/bin/activate -pip install -U matrix-synapse[all] +pip install -U matrix-synapse ``` Before you can start Synapse, you will need to generate a configuration @@ -257,9 +257,8 @@ https://github.com/spantaleev/matrix-docker-ansible-deploy #### Matrix.org packages Matrix.org provides Debian/Ubuntu packages of the latest stable version of -Synapse via https://packages.matrix.org/debian/. To use them: - -For Debian 9 (Stretch), Ubuntu 16.04 (Xenial), and later: +Synapse via https://packages.matrix.org/debian/. They are available for Debian +9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them: ``` sudo apt install -y lsb-release wget apt-transport-https @@ -270,17 +269,6 @@ sudo apt update sudo apt install matrix-synapse-py3 ``` -For Debian 8 (Jessie): - -``` -sudo apt install -y lsb-release wget apt-transport-https -sudo wget -O /etc/apt/trusted.gpg.d/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg -echo "deb [signed-by=5586CCC0CBBBEFC7A25811ADF473DD4473365DE1] https://packages.matrix.org/debian/ $(lsb_release -cs) main" | - sudo tee /etc/apt/sources.list.d/matrix-org.list -sudo apt update -sudo apt install matrix-synapse-py3 -``` - **Note**: if you followed a previous version of these instructions which recommended using `apt-key add` to add an old key from `https://matrix.org/packages/debian/`, you should note that this key has been @@ -288,6 +276,9 @@ revoked. You should remove the old key with `sudo apt-key remove C35EB17E1EAE708E6603A9B3AD0592FE47F0DF61`, and follow the above instructions to update your configuration. +The fingerprint of the repository signing key (as shown by `gpg +/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is +`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. #### Downstream Debian/Ubuntu packages diff --git a/README.rst b/README.rst index 24afb93d7..5409f0c56 100644 --- a/README.rst +++ b/README.rst @@ -173,7 +173,7 @@ Synapse offers two database engines: * `PostgreSQL `_ By default Synapse uses SQLite in and doing so trades performance for convenience. -SQLite is only recommended in Synapse for testing purposes or for servers with +SQLite is only recommended in Synapse for testing purposes or for servers with light workloads. Almost all installations should opt to use PostreSQL. Advantages include: @@ -272,7 +272,7 @@ to install using pip and a virtualenv:: virtualenv -p python3 env source env/bin/activate - python -m pip install -e .[all] + python -m pip install --no-pep-517 -e .[all] This will run a process of downloading and installing all the needed dependencies into a virtual env. diff --git a/changelog.d/4338.feature b/changelog.d/4338.feature new file mode 100644 index 000000000..01285e965 --- /dev/null +++ b/changelog.d/4338.feature @@ -0,0 +1 @@ +Synapse now more efficiently collates room statistics. diff --git a/changelog.d/4339.feature b/changelog.d/4339.feature deleted file mode 100644 index cecff97b8..000000000 --- a/changelog.d/4339.feature +++ /dev/null @@ -1 +0,0 @@ -Add systemd-python to the optional dependencies to enable logging to the systemd journal. Install with `pip install matrix-synapse[systemd]`. diff --git a/changelog.d/4474.misc b/changelog.d/4474.misc deleted file mode 100644 index 4b882d60b..000000000 --- a/changelog.d/4474.misc +++ /dev/null @@ -1 +0,0 @@ -Add test to verify threepid auth check added in #4435. diff --git a/changelog.d/4555.bugfix b/changelog.d/4555.bugfix deleted file mode 100644 index d596022c3..000000000 --- a/changelog.d/4555.bugfix +++ /dev/null @@ -1 +0,0 @@ -Avoid redundant URL encoding of redirect URL for SSO login in the fallback login page. Fixes a regression introduced in [#4220](https://github.com/matrix-org/synapse/pull/4220). Contributed by Marcel Fabian Krüger ("[zaugin](https://github.com/zauguin)"). diff --git a/changelog.d/4942.bugfix b/changelog.d/4942.bugfix deleted file mode 100644 index 590d80d58..000000000 --- a/changelog.d/4942.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server. diff --git a/changelog.d/4947.feature b/changelog.d/4947.feature deleted file mode 100644 index b9d27b90f..000000000 --- a/changelog.d/4947.feature +++ /dev/null @@ -1 +0,0 @@ -Add ability for password provider modules to bind email addresses to users upon registration. \ No newline at end of file diff --git a/changelog.d/4949.misc b/changelog.d/4949.misc deleted file mode 100644 index 25c4e05a6..000000000 --- a/changelog.d/4949.misc +++ /dev/null @@ -1 +0,0 @@ -Fix/improve some docstrings in the replication code. diff --git a/changelog.d/4953.misc b/changelog.d/4953.misc deleted file mode 100644 index 06a084e6e..000000000 --- a/changelog.d/4953.misc +++ /dev/null @@ -1,2 +0,0 @@ -Split synapse.replication.tcp.streams into smaller files. - diff --git a/changelog.d/4954.misc b/changelog.d/4954.misc deleted file mode 100644 index 91f145950..000000000 --- a/changelog.d/4954.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor replication row generation/parsing. diff --git a/changelog.d/4955.bugfix b/changelog.d/4955.bugfix deleted file mode 100644 index e50e67383..000000000 --- a/changelog.d/4955.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix sync bug which made accepting invites unreliable in worker-mode synapses. diff --git a/changelog.d/4956.bugfix b/changelog.d/4956.bugfix deleted file mode 100644 index e50e67383..000000000 --- a/changelog.d/4956.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix sync bug which made accepting invites unreliable in worker-mode synapses. diff --git a/changelog.d/4959.misc b/changelog.d/4959.misc deleted file mode 100644 index dd4275501..000000000 --- a/changelog.d/4959.misc +++ /dev/null @@ -1 +0,0 @@ -Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`. \ No newline at end of file diff --git a/changelog.d/4965.misc b/changelog.d/4965.misc deleted file mode 100644 index 284c58b75..000000000 --- a/changelog.d/4965.misc +++ /dev/null @@ -1 +0,0 @@ -Remove log line for password via the admin API. diff --git a/changelog.d/4968.misc b/changelog.d/4968.misc deleted file mode 100644 index 7a7b69771..000000000 --- a/changelog.d/4968.misc +++ /dev/null @@ -1 +0,0 @@ -Fix typo in TLS filenames in docker/README.md. Also add the '-p' commandline option to the 'docker run' example. Contributed by Jurrie Overgoor. diff --git a/changelog.d/4969.misc b/changelog.d/4969.misc deleted file mode 100644 index e3a3214e6..000000000 --- a/changelog.d/4969.misc +++ /dev/null @@ -1,2 +0,0 @@ -Refactor room version definitions. - diff --git a/changelog.d/4974.misc b/changelog.d/4974.misc deleted file mode 100644 index 672a18923..000000000 --- a/changelog.d/4974.misc +++ /dev/null @@ -1 +0,0 @@ -Add `config.signing_key_path` that can be read by `synapse.config` utility. diff --git a/changelog.d/4981.bugfix b/changelog.d/4981.bugfix deleted file mode 100644 index e51b45eec..000000000 --- a/changelog.d/4981.bugfix +++ /dev/null @@ -1 +0,0 @@ -start.sh: Fix the --no-rate-limit option for messages and make it bypass rate limit on registration and login too. \ No newline at end of file diff --git a/changelog.d/4982.misc b/changelog.d/4982.misc deleted file mode 100644 index 067c177d3..000000000 --- a/changelog.d/4982.misc +++ /dev/null @@ -1 +0,0 @@ -Track which identity server is used when binding a threepid and use that for unbinding, as per MSC1915. diff --git a/changelog.d/4985.misc b/changelog.d/4985.misc deleted file mode 100644 index 50c9ff9e0..000000000 --- a/changelog.d/4985.misc +++ /dev/null @@ -1 +0,0 @@ -Rewrite KeyringTestCase as a HomeserverTestCase. diff --git a/changelog.d/4987.misc b/changelog.d/4987.misc deleted file mode 100644 index 33490e146..000000000 --- a/changelog.d/4987.misc +++ /dev/null @@ -1 +0,0 @@ -README updates: Corrected the default POSTGRES_USER. Added port forwarding hint in TLS section. diff --git a/changelog.d/4989.feature b/changelog.d/4989.feature deleted file mode 100644 index a5138f561..000000000 --- a/changelog.d/4989.feature +++ /dev/null @@ -1 +0,0 @@ -Remove presence list support as per MSC 1819. diff --git a/changelog.d/4990.bugfix b/changelog.d/4990.bugfix deleted file mode 100644 index 1b69d058f..000000000 --- a/changelog.d/4990.bugfix +++ /dev/null @@ -1 +0,0 @@ -Transfer related groups on room upgrade. \ No newline at end of file diff --git a/changelog.d/4991.feature b/changelog.d/4991.feature deleted file mode 100644 index 034bf3239..000000000 --- a/changelog.d/4991.feature +++ /dev/null @@ -1 +0,0 @@ -Reduce CPU usage starting pushers during start up. diff --git a/changelog.d/4992.misc b/changelog.d/4992.misc deleted file mode 100644 index 3ee4228c0..000000000 --- a/changelog.d/4992.misc +++ /dev/null @@ -1 +0,0 @@ -Remove a number of unused tables from the database schema. diff --git a/changelog.d/4996.misc b/changelog.d/4996.misc deleted file mode 100644 index ecac24e2b..000000000 --- a/changelog.d/4996.misc +++ /dev/null @@ -1 +0,0 @@ -Run `black` on the remainder of `synapse/storage/`. \ No newline at end of file diff --git a/changelog.d/4998.misc b/changelog.d/4998.misc deleted file mode 100644 index 7caf95913..000000000 --- a/changelog.d/4998.misc +++ /dev/null @@ -1 +0,0 @@ -Fix grammar in get_current_users_in_room and give it a docstring. diff --git a/changelog.d/4999.bugfix b/changelog.d/4999.bugfix deleted file mode 100644 index acbc19196..000000000 --- a/changelog.d/4999.bugfix +++ /dev/null @@ -1 +0,0 @@ -Prevent the ability to kick users from a room they aren't in. diff --git a/changelog.d/5001.misc b/changelog.d/5001.misc deleted file mode 100644 index bf590a016..000000000 --- a/changelog.d/5001.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some code in the server-key Keyring. \ No newline at end of file diff --git a/changelog.d/5002.feature b/changelog.d/5002.feature deleted file mode 100644 index d8f50e963..000000000 --- a/changelog.d/5002.feature +++ /dev/null @@ -1 +0,0 @@ -Add a delete group admin API. diff --git a/changelog.d/5003.bugfix b/changelog.d/5003.bugfix deleted file mode 100644 index 9955dc871..000000000 --- a/changelog.d/5003.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix issue #4596 so synapse_port_db script works with --curses option on Python 3. Contributed by Anders Jensen-Waud . diff --git a/changelog.d/5007.misc b/changelog.d/5007.misc deleted file mode 100644 index 05b6ce2c2..000000000 --- a/changelog.d/5007.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor synapse.storage._base._simple_select_list_paginate. \ No newline at end of file diff --git a/changelog.d/5010.feature b/changelog.d/5010.feature deleted file mode 100644 index 65ab198b7..000000000 --- a/changelog.d/5010.feature +++ /dev/null @@ -1 +0,0 @@ -Add config option to block users from looking up 3PIDs. diff --git a/changelog.d/5020.feature b/changelog.d/5020.feature deleted file mode 100644 index 71f7a8db2..000000000 --- a/changelog.d/5020.feature +++ /dev/null @@ -1 +0,0 @@ -Add context to phonehome stats. diff --git a/changelog.d/5024.misc b/changelog.d/5024.misc deleted file mode 100644 index 07c13f28d..000000000 --- a/changelog.d/5024.misc +++ /dev/null @@ -1 +0,0 @@ -Store the notary server name correctly in server_keys_json. diff --git a/changelog.d/5027.feature b/changelog.d/5027.feature deleted file mode 100644 index 12766a82a..000000000 --- a/changelog.d/5027.feature +++ /dev/null @@ -1 +0,0 @@ -Add time-based account expiration. diff --git a/changelog.d/5028.misc b/changelog.d/5028.misc deleted file mode 100644 index 3ee4228c0..000000000 --- a/changelog.d/5028.misc +++ /dev/null @@ -1 +0,0 @@ -Remove a number of unused tables from the database schema. diff --git a/changelog.d/5030.misc b/changelog.d/5030.misc deleted file mode 100644 index 3456eb538..000000000 --- a/changelog.d/5030.misc +++ /dev/null @@ -1 +0,0 @@ -Rewrite Datastore.get_server_verify_keys to reduce the number of database transactions. diff --git a/changelog.d/5032.bugfix b/changelog.d/5032.bugfix deleted file mode 100644 index cd71180ce..000000000 --- a/changelog.d/5032.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix "cannot import name execute_batch" error with postgres. diff --git a/changelog.d/5033.misc b/changelog.d/5033.misc deleted file mode 100644 index 3ee4228c0..000000000 --- a/changelog.d/5033.misc +++ /dev/null @@ -1 +0,0 @@ -Remove a number of unused tables from the database schema. diff --git a/changelog.d/5035.bugfix b/changelog.d/5035.bugfix deleted file mode 100644 index 85e154027..000000000 --- a/changelog.d/5035.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix disappearing exceptions in manhole. diff --git a/changelog.d/5046.misc b/changelog.d/5046.misc deleted file mode 100644 index eb966a5ae..000000000 --- a/changelog.d/5046.misc +++ /dev/null @@ -1 +0,0 @@ -Remove extraneous period from copyright headers. diff --git a/changelog.d/5047.feature b/changelog.d/5047.feature deleted file mode 100644 index 12766a82a..000000000 --- a/changelog.d/5047.feature +++ /dev/null @@ -1 +0,0 @@ -Add time-based account expiration. diff --git a/changelog.d/5063.feature b/changelog.d/5063.feature deleted file mode 100644 index fd7b80018..000000000 --- a/changelog.d/5063.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for handling /verions, /voip and /push_rules client endpoints to client_reader worker. diff --git a/changelog.d/5065.feature b/changelog.d/5065.feature deleted file mode 100644 index fd7b80018..000000000 --- a/changelog.d/5065.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for handling /verions, /voip and /push_rules client endpoints to client_reader worker. diff --git a/changelog.d/5067.misc b/changelog.d/5067.misc deleted file mode 100644 index bbb4337db..000000000 --- a/changelog.d/5067.misc +++ /dev/null @@ -1 +0,0 @@ -Update documentation for where to get Synapse packages. diff --git a/changelog.d/5070.feature b/changelog.d/5070.feature deleted file mode 100644 index fd7b80018..000000000 --- a/changelog.d/5070.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for handling /verions, /voip and /push_rules client endpoints to client_reader worker. diff --git a/changelog.d/5071.bugfix b/changelog.d/5071.bugfix deleted file mode 100644 index ddf7ab5fa..000000000 --- a/changelog.d/5071.bugfix +++ /dev/null @@ -1 +0,0 @@ -Make sure we're not registering the same 3pid twice on registration. diff --git a/changelog.d/5073.feature b/changelog.d/5073.feature deleted file mode 100644 index 12766a82a..000000000 --- a/changelog.d/5073.feature +++ /dev/null @@ -1 +0,0 @@ -Add time-based account expiration. diff --git a/changelog.d/5077.bugfix b/changelog.d/5077.bugfix deleted file mode 100644 index f3345a635..000000000 --- a/changelog.d/5077.bugfix +++ /dev/null @@ -1 +0,0 @@ -Don't crash on lack of expiry templates. diff --git a/changelog.d/5200.bugfix b/changelog.d/5200.bugfix new file mode 100644 index 000000000..f346c7b0c --- /dev/null +++ b/changelog.d/5200.bugfix @@ -0,0 +1 @@ +Fix worker registration bug caused by ClientReaderSlavedStore being unable to see get_profileinfo. diff --git a/changelog.d/5216.misc b/changelog.d/5216.misc new file mode 100644 index 000000000..dbfa29475 --- /dev/null +++ b/changelog.d/5216.misc @@ -0,0 +1 @@ +Synapse will now serve the experimental "room complexity" API endpoint. diff --git a/changelog.d/5220.feature b/changelog.d/5220.feature new file mode 100644 index 000000000..747098c16 --- /dev/null +++ b/changelog.d/5220.feature @@ -0,0 +1 @@ +Add experimental support for relations (aka reactions and edits). diff --git a/changelog.d/5223.feature b/changelog.d/5223.feature new file mode 100644 index 000000000..cfdf1ad41 --- /dev/null +++ b/changelog.d/5223.feature @@ -0,0 +1 @@ +Ability to configure default room version. diff --git a/changelog.d/5226.misc b/changelog.d/5226.misc new file mode 100644 index 000000000..e1b9dc58a --- /dev/null +++ b/changelog.d/5226.misc @@ -0,0 +1 @@ +The base classes for the v1 and v2_alpha REST APIs have been unified. diff --git a/changelog.d/5227.misc b/changelog.d/5227.misc new file mode 100644 index 000000000..32bd7b600 --- /dev/null +++ b/changelog.d/5227.misc @@ -0,0 +1 @@ +Simplifications and comments in do_auth. diff --git a/changelog.d/5230.misc b/changelog.d/5230.misc new file mode 100644 index 000000000..c681bc974 --- /dev/null +++ b/changelog.d/5230.misc @@ -0,0 +1 @@ +Remove urllib3 pin as requests 2.22.0 has been released supporting urllib3 1.25.2. diff --git a/changelog.d/5232.misc b/changelog.d/5232.misc new file mode 100644 index 000000000..1cdc71f09 --- /dev/null +++ b/changelog.d/5232.misc @@ -0,0 +1 @@ +Run black on synapse.crypto.keyring. diff --git a/changelog.d/5233.bugfix b/changelog.d/5233.bugfix new file mode 100644 index 000000000..d71b96216 --- /dev/null +++ b/changelog.d/5233.bugfix @@ -0,0 +1 @@ +Fix appservice timestamp massaging. diff --git a/changelog.d/5234.misc b/changelog.d/5234.misc new file mode 100644 index 000000000..43fbd6f67 --- /dev/null +++ b/changelog.d/5234.misc @@ -0,0 +1 @@ +Rewrite store_server_verify_key to store several keys at once. diff --git a/changelog.d/5235.misc b/changelog.d/5235.misc new file mode 100644 index 000000000..2296ad2a4 --- /dev/null +++ b/changelog.d/5235.misc @@ -0,0 +1 @@ +Remove unused VerifyKey.expired and .time_added fields. diff --git a/changelog.d/5236.misc b/changelog.d/5236.misc new file mode 100644 index 000000000..cb4417a9f --- /dev/null +++ b/changelog.d/5236.misc @@ -0,0 +1 @@ +Simplify Keyring.process_v2_response. \ No newline at end of file diff --git a/changelog.d/5237.misc b/changelog.d/5237.misc new file mode 100644 index 000000000..f4fe3b821 --- /dev/null +++ b/changelog.d/5237.misc @@ -0,0 +1 @@ +Store key validity time in the storage layer. diff --git a/changelog.d/5244.misc b/changelog.d/5244.misc new file mode 100644 index 000000000..9cc1fb869 --- /dev/null +++ b/changelog.d/5244.misc @@ -0,0 +1 @@ +Refactor synapse.crypto.keyring to use a KeyFetcher interface. diff --git a/changelog.d/5249.feature b/changelog.d/5249.feature new file mode 100644 index 000000000..cfdf1ad41 --- /dev/null +++ b/changelog.d/5249.feature @@ -0,0 +1 @@ +Ability to configure default room version. diff --git a/changelog.d/5250.misc b/changelog.d/5250.misc new file mode 100644 index 000000000..575a299a8 --- /dev/null +++ b/changelog.d/5250.misc @@ -0,0 +1 @@ +Simplification to Keyring.wait_for_previous_lookups. diff --git a/changelog.d/5251.bugfix b/changelog.d/5251.bugfix new file mode 100644 index 000000000..9a053204b --- /dev/null +++ b/changelog.d/5251.bugfix @@ -0,0 +1 @@ +Ensure that server_keys fetched via a notary server are correctly signed. \ No newline at end of file diff --git a/changelog.d/5256.bugfix b/changelog.d/5256.bugfix new file mode 100644 index 000000000..86316ab5d --- /dev/null +++ b/changelog.d/5256.bugfix @@ -0,0 +1 @@ +Show the correct error when logging out and access token is missing. diff --git a/changelog.d/5257.bugfix b/changelog.d/5257.bugfix new file mode 100644 index 000000000..8334af9b9 --- /dev/null +++ b/changelog.d/5257.bugfix @@ -0,0 +1 @@ +Fix error code when there is an invalid parameter on /_matrix/client/r0/publicRooms diff --git a/changelog.d/5258.bugfix b/changelog.d/5258.bugfix new file mode 100644 index 000000000..fb5d44aed --- /dev/null +++ b/changelog.d/5258.bugfix @@ -0,0 +1 @@ +Fix error when downloading thumbnail with missing width/height parameter. diff --git a/changelog.d/5260.feature b/changelog.d/5260.feature new file mode 100644 index 000000000..01285e965 --- /dev/null +++ b/changelog.d/5260.feature @@ -0,0 +1 @@ +Synapse now more efficiently collates room statistics. diff --git a/changelog.d/5268.bugfix b/changelog.d/5268.bugfix new file mode 100644 index 000000000..1a5a03bf0 --- /dev/null +++ b/changelog.d/5268.bugfix @@ -0,0 +1 @@ +Fix schema update for account validity. diff --git a/changelog.d/5274.bugfix b/changelog.d/5274.bugfix new file mode 100644 index 000000000..9e14d2028 --- /dev/null +++ b/changelog.d/5274.bugfix @@ -0,0 +1 @@ +Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. diff --git a/changelog.d/5275.bugfix b/changelog.d/5275.bugfix new file mode 100644 index 000000000..45a554642 --- /dev/null +++ b/changelog.d/5275.bugfix @@ -0,0 +1 @@ +Fix "db txn 'update_presence' from sentinel context" log messages. diff --git a/changelog.d/5276.feature b/changelog.d/5276.feature new file mode 100644 index 000000000..403dee086 --- /dev/null +++ b/changelog.d/5276.feature @@ -0,0 +1 @@ +Allow configuring a range for the account validity startup job. diff --git a/changelog.d/5277.bugfix b/changelog.d/5277.bugfix new file mode 100644 index 000000000..371aa2e7f --- /dev/null +++ b/changelog.d/5277.bugfix @@ -0,0 +1 @@ +Fix dropped logcontexts during high outbound traffic. diff --git a/changelog.d/5278.bugfix b/changelog.d/5278.bugfix new file mode 100644 index 000000000..9e14d2028 --- /dev/null +++ b/changelog.d/5278.bugfix @@ -0,0 +1 @@ +Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. diff --git a/changelog.d/5282.doc b/changelog.d/5282.doc new file mode 100644 index 000000000..350e15bc0 --- /dev/null +++ b/changelog.d/5282.doc @@ -0,0 +1 @@ +Fix docs on resetting the user directory. diff --git a/changelog.d/5283.misc b/changelog.d/5283.misc new file mode 100644 index 000000000..002721e56 --- /dev/null +++ b/changelog.d/5283.misc @@ -0,0 +1 @@ +Specify the type of reCAPTCHA key to use. diff --git a/changelog.d/5286.feature b/changelog.d/5286.feature new file mode 100644 index 000000000..81860279a --- /dev/null +++ b/changelog.d/5286.feature @@ -0,0 +1 @@ +CAS login will now hit the r0 API, not the deprecated v1 one. diff --git a/changelog.d/5287.misc b/changelog.d/5287.misc new file mode 100644 index 000000000..1286f1dd0 --- /dev/null +++ b/changelog.d/5287.misc @@ -0,0 +1 @@ +Remove spurious debug from MatrixFederationHttpClient.get_json. diff --git a/changelog.d/5288.misc b/changelog.d/5288.misc new file mode 100644 index 000000000..fbf049ba6 --- /dev/null +++ b/changelog.d/5288.misc @@ -0,0 +1 @@ +Improve logging for logcontext leaks. diff --git a/changelog.d/5291.bugfix b/changelog.d/5291.bugfix new file mode 100644 index 000000000..9e14d2028 --- /dev/null +++ b/changelog.d/5291.bugfix @@ -0,0 +1 @@ +Fix bug where we leaked extremities when we soft failed events, leading to performance degradation. diff --git a/changelog.d/5293.bugfix b/changelog.d/5293.bugfix new file mode 100644 index 000000000..aa519a843 --- /dev/null +++ b/changelog.d/5293.bugfix @@ -0,0 +1 @@ +Fix a bug where it is not possible to get events in the federation format with the request `GET /_matrix/client/r0/rooms/{roomId}/messages`. diff --git a/changelog.d/5294.bugfix b/changelog.d/5294.bugfix new file mode 100644 index 000000000..5924bda31 --- /dev/null +++ b/changelog.d/5294.bugfix @@ -0,0 +1 @@ +Fix performance problems with the rooms stats background update. diff --git a/changelog.d/5296.misc b/changelog.d/5296.misc new file mode 100644 index 000000000..a038a6f7f --- /dev/null +++ b/changelog.d/5296.misc @@ -0,0 +1 @@ +Refactor keyring.VerifyKeyRequest to use attr.s. diff --git a/changelog.d/5299.misc b/changelog.d/5299.misc new file mode 100644 index 000000000..53297c768 --- /dev/null +++ b/changelog.d/5299.misc @@ -0,0 +1 @@ +Rewrite get_server_verify_keys, again. diff --git a/changelog.d/5300.bugfix b/changelog.d/5300.bugfix new file mode 100644 index 000000000..049e93cd5 --- /dev/null +++ b/changelog.d/5300.bugfix @@ -0,0 +1 @@ +Fix noisy 'no key for server' logs. diff --git a/changelog.d/5303.misc b/changelog.d/5303.misc new file mode 100644 index 000000000..f6a7f1f8e --- /dev/null +++ b/changelog.d/5303.misc @@ -0,0 +1 @@ +Clarify that the admin change password API logs the user out. diff --git a/changelog.d/5307.bugfix b/changelog.d/5307.bugfix new file mode 100644 index 000000000..6b152f485 --- /dev/null +++ b/changelog.d/5307.bugfix @@ -0,0 +1 @@ +Fix bug where a notary server would sometimes forget old keys. diff --git a/changelog.d/5309.bugfix b/changelog.d/5309.bugfix new file mode 100644 index 000000000..97b352726 --- /dev/null +++ b/changelog.d/5309.bugfix @@ -0,0 +1 @@ +Prevent users from setting huge displaynames and avatar URLs. diff --git a/changelog.d/5321.bugfix b/changelog.d/5321.bugfix new file mode 100644 index 000000000..943a61956 --- /dev/null +++ b/changelog.d/5321.bugfix @@ -0,0 +1 @@ +Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests. diff --git a/changelog.d/5324.feature b/changelog.d/5324.feature new file mode 100644 index 000000000..01285e965 --- /dev/null +++ b/changelog.d/5324.feature @@ -0,0 +1 @@ +Synapse now more efficiently collates room statistics. diff --git a/changelog.d/5328.misc b/changelog.d/5328.misc new file mode 100644 index 000000000..e1b9dc58a --- /dev/null +++ b/changelog.d/5328.misc @@ -0,0 +1 @@ +The base classes for the v1 and v2_alpha REST APIs have been unified. diff --git a/changelog.d/5332.misc b/changelog.d/5332.misc new file mode 100644 index 000000000..dcfac4eac --- /dev/null +++ b/changelog.d/5332.misc @@ -0,0 +1 @@ +Improve docstrings on MatrixFederationClient. diff --git a/changelog.d/5333.bugfix b/changelog.d/5333.bugfix new file mode 100644 index 000000000..cb05a6dd6 --- /dev/null +++ b/changelog.d/5333.bugfix @@ -0,0 +1 @@ +Fix various problems which made the signing-key notary server time out for some requests. \ No newline at end of file diff --git a/changelog.d/5334.bugfix b/changelog.d/5334.bugfix new file mode 100644 index 000000000..ed141e091 --- /dev/null +++ b/changelog.d/5334.bugfix @@ -0,0 +1 @@ +Fix bug which would make certain operations (such as room joins) block for 20 minutes while attemoting to fetch verification keys. diff --git a/changelog.d/5335.bugfix b/changelog.d/5335.bugfix new file mode 100644 index 000000000..7318cbe35 --- /dev/null +++ b/changelog.d/5335.bugfix @@ -0,0 +1 @@ +Fix a bug where we could rapidly mark a server as unreachable even though it was only down for a few minutes. diff --git a/changelog.d/5340.bugfix b/changelog.d/5340.bugfix new file mode 100644 index 000000000..931ee904e --- /dev/null +++ b/changelog.d/5340.bugfix @@ -0,0 +1,2 @@ +Fix a bug where we could rapidly mark a server as unreachable even though it was only down for a few minutes. + diff --git a/changelog.d/5341.bugfix b/changelog.d/5341.bugfix new file mode 100644 index 000000000..a7aaa95f3 --- /dev/null +++ b/changelog.d/5341.bugfix @@ -0,0 +1 @@ +Fix a bug where account validity renewal emails could only be sent when email notifs were enabled. diff --git a/changelog.d/5342.bugfix b/changelog.d/5342.bugfix new file mode 100644 index 000000000..66a307629 --- /dev/null +++ b/changelog.d/5342.bugfix @@ -0,0 +1 @@ +Fix failure when fetching batches of events during backfill, etc. diff --git a/changelog.d/5343.misc b/changelog.d/5343.misc new file mode 100644 index 000000000..dbee0f71b --- /dev/null +++ b/changelog.d/5343.misc @@ -0,0 +1 @@ +Rename VerifyKeyRequest.deferred field. diff --git a/changelog.d/5344.misc b/changelog.d/5344.misc new file mode 100644 index 000000000..a20c563bf --- /dev/null +++ b/changelog.d/5344.misc @@ -0,0 +1 @@ +Clean up FederationClient.get_events for clarity. diff --git a/changelog.d/5347.misc b/changelog.d/5347.misc new file mode 100644 index 000000000..436245fb1 --- /dev/null +++ b/changelog.d/5347.misc @@ -0,0 +1,2 @@ +Various improvements to debug logging. + diff --git a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service index 912984b9d..9d980d516 100644 --- a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service +++ b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service @@ -12,6 +12,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.%i --config-path=/ ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse-%i [Install] WantedBy=matrix-synapse.service diff --git a/contrib/systemd-with-workers/system/matrix-synapse.service b/contrib/systemd-with-workers/system/matrix-synapse.service index 8bb4e400d..3aae19034 100644 --- a/contrib/systemd-with-workers/system/matrix-synapse.service +++ b/contrib/systemd-with-workers/system/matrix-synapse.service @@ -11,6 +11,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.homeserver --confi ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse [Install] WantedBy=matrix.target diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service index efb157e94..595b69916 100644 --- a/contrib/systemd/matrix-synapse.service +++ b/contrib/systemd/matrix-synapse.service @@ -22,10 +22,10 @@ Group=nogroup WorkingDirectory=/opt/synapse ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml +SyslogIdentifier=matrix-synapse # adjust the cache factor if necessary # Environment=SYNAPSE_CACHE_FACTOR=2.0 [Install] WantedBy=multi-user.target - diff --git a/debian/changelog b/debian/changelog index 03df2e1c0..6a1a72c0e 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,37 @@ +matrix-synapse-py3 (0.99.5.2) stable; urgency=medium + + * New synapse release 0.99.5.2. + + -- Synapse Packaging team Thu, 30 May 2019 16:28:07 +0100 + +matrix-synapse-py3 (0.99.5.1) stable; urgency=medium + + * New synapse release 0.99.5.1. + + -- Synapse Packaging team Wed, 22 May 2019 16:22:24 +0000 + +matrix-synapse-py3 (0.99.4) stable; urgency=medium + + [ Christoph Müller ] + * Configure the systemd units to have a log identifier of `matrix-synapse` + + [ Synapse Packaging team ] + * New synapse release 0.99.4. + + -- Synapse Packaging team Wed, 15 May 2019 13:58:08 +0100 + +matrix-synapse-py3 (0.99.3.2) stable; urgency=medium + + * New synapse release 0.99.3.2. + + -- Synapse Packaging team Fri, 03 May 2019 18:56:20 +0100 + +matrix-synapse-py3 (0.99.3.1) stable; urgency=medium + + * New synapse release 0.99.3.1. + + -- Synapse Packaging team Fri, 03 May 2019 16:02:43 +0100 + matrix-synapse-py3 (0.99.3) stable; urgency=medium [ Richard van der Hoff ] diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service index 942e4b83f..b0a8d72e6 100644 --- a/debian/matrix-synapse.service +++ b/debian/matrix-synapse.service @@ -11,6 +11,7 @@ ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.homeserver --confi ExecReload=/bin/kill -HUP $MAINPID Restart=always RestartSec=3 +SyslogIdentifier=matrix-synapse [Install] WantedBy=multi-user.target diff --git a/debian/test/.gitignore b/debian/test/.gitignore new file mode 100644 index 000000000..95eda73fc --- /dev/null +++ b/debian/test/.gitignore @@ -0,0 +1,2 @@ +.vagrant +*.log diff --git a/debian/test/provision.sh b/debian/test/provision.sh new file mode 100644 index 000000000..a5c7f5971 --- /dev/null +++ b/debian/test/provision.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# +# provisioning script for vagrant boxes for testing the matrix-synapse debs. +# +# Will install the most recent matrix-synapse-py3 deb for this platform from +# the /debs directory. + +set -e + +apt-get update +apt-get install -y lsb-release + +deb=`ls /debs/matrix-synapse-py3_*+$(lsb_release -cs)*.deb | sort | tail -n1` + +debconf-set-selections < +POST /_synapse/admin/v1/delete_group/ ``` including an `access_token` of a server admin. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index abdbc1ea8..5e9f8e5d8 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,7 +4,7 @@ This API gets a list of known media in a room. The API is: ``` -GET /_matrix/client/r0/admin/room//media +GET /_synapse/admin/v1/room//media ``` including an `access_token` of a server admin. diff --git a/docs/admin_api/purge_history_api.rst b/docs/admin_api/purge_history_api.rst index a5c3dc814..f7be226fd 100644 --- a/docs/admin_api/purge_history_api.rst +++ b/docs/admin_api/purge_history_api.rst @@ -10,7 +10,7 @@ paginate further back in the room from the point being purged from. The API is: -``POST /_matrix/client/r0/admin/purge_history/[/]`` +``POST /_synapse/admin/v1/purge_history/[/]`` including an ``access_token`` of a server admin. @@ -49,7 +49,7 @@ Purge status query It is possible to poll for updates on recent purges with a second API; -``GET /_matrix/client/r0/admin/purge_history_status/`` +``GET /_synapse/admin/v1/purge_history_status/`` (again, with a suitable ``access_token``). This API returns a JSON body like the following: diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst index 5deb02a3d..dacd5bc8f 100644 --- a/docs/admin_api/purge_remote_media.rst +++ b/docs/admin_api/purge_remote_media.rst @@ -6,7 +6,7 @@ media. The API is:: - POST /_matrix/client/r0/admin/purge_media_cache?before_ts=&access_token= + POST /_synapse/admin/v1/purge_media_cache?before_ts=&access_token= {} diff --git a/docs/admin_api/register_api.rst b/docs/admin_api/register_api.rst index 084e74ebf..3a63109aa 100644 --- a/docs/admin_api/register_api.rst +++ b/docs/admin_api/register_api.rst @@ -12,7 +12,7 @@ is not enabled. To fetch the nonce, you need to request one from the API:: - > GET /_matrix/client/r0/admin/register + > GET /_synapse/admin/v1/register < {"nonce": "thisisanonce"} @@ -22,7 +22,7 @@ body containing the nonce, username, password, whether they are an admin As an example:: - > POST /_matrix/client/r0/admin/register + > POST /_synapse/admin/v1/register > { "nonce": "thisisanonce", "username": "pepper_roni", diff --git a/docs/admin_api/server_notices.md b/docs/admin_api/server_notices.md new file mode 100644 index 000000000..858b052b8 --- /dev/null +++ b/docs/admin_api/server_notices.md @@ -0,0 +1,48 @@ +# Server Notices + +The API to send notices is as follows: + +``` +POST /_synapse/admin/v1/send_server_notice +``` + +or: + +``` +PUT /_synapse/admin/v1/send_server_notice/{txnId} +``` + +You will need to authenticate with an access token for an admin user. + +When using the `PUT` form, retransmissions with the same transaction ID will be +ignored in the same way as with `PUT +/_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}`. + +The request body should look something like the following: + +```json +{ + "user_id": "@target_user:server_name", + "content": { + "msgtype": "m.text", + "body": "This is my message" + } +} +``` + +You can optionally include the following additional parameters: + +* `type`: the type of event. Defaults to `m.room.message`. +* `state_key`: Setting this will result in a state event being sent. + + +Once the notice has been sent, the API will return the following response: + +```json +{ + "event_id": "" +} +``` + +Note that server notices must be enabled in `homeserver.yaml` before this API +can be used. See [server_notices.md](../server_notices.md) for more information. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d17121a18..213359d0c 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -5,7 +5,7 @@ This API returns information about a specific user account. The api is:: - GET /_matrix/client/r0/admin/whois/ + GET /_synapse/admin/v1/whois/ including an ``access_token`` of a server admin. @@ -50,7 +50,7 @@ references to it). The api is:: - POST /_matrix/client/r0/admin/deactivate/ + POST /_synapse/admin/v1/deactivate/ with a body of: @@ -69,11 +69,11 @@ An empty body may be passed for backwards compatibility. Reset password ============== -Changes the password of another user. +Changes the password of another user. This will automatically log the user out of all their devices. The api is:: - POST /_matrix/client/r0/admin/reset_password/ + POST /_synapse/admin/v1/reset_password/ with a body of: diff --git a/docs/admin_api/version_api.rst b/docs/admin_api/version_api.rst index 30a91b5f4..833d9028b 100644 --- a/docs/admin_api/version_api.rst +++ b/docs/admin_api/version_api.rst @@ -8,9 +8,7 @@ contains Synapse version information). The api is:: - GET /_matrix/client/r0/admin/server_version - -including an ``access_token`` of a server admin. + GET /_synapse/admin/v1/server_version It returns a JSON body like the following: diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst index 5bbb5a4f3..32b064e2d 100644 --- a/docs/metrics-howto.rst +++ b/docs/metrics-howto.rst @@ -48,7 +48,10 @@ How to monitor Synapse metrics using Prometheus - job_name: "synapse" metrics_path: "/_synapse/metrics" static_configs: - - targets: ["my.server.here:9092"] + - targets: ["my.server.here:port"] + + where ``my.server.here`` is the IP address of Synapse, and ``port`` is the listener port + configured with the ``metrics`` resource. If your prometheus is older than 1.5.2, you will need to replace ``static_configs`` in the above with ``target_groups``. diff --git a/docs/postgres.rst b/docs/postgres.rst index f7ebbed0c..e81e10403 100644 --- a/docs/postgres.rst +++ b/docs/postgres.rst @@ -3,6 +3,28 @@ Using Postgres Postgres version 9.4 or later is known to work. +Install postgres client libraries +================================= + +Synapse will require the python postgres client library in order to connect to +a postgres database. + +* If you are using the `matrix.org debian/ubuntu + packages <../INSTALL.md#matrixorg-packages>`_, + the necessary libraries will already be installed. + +* For other pre-built packages, please consult the documentation from the + relevant package. + +* If you installed synapse `in a virtualenv + <../INSTALL.md#installing-from-source>`_, you can install the library with:: + + ~/synapse/env/bin/pip install matrix-synapse[postgres] + + (substituting the path to your virtualenv for ``~/synapse/env``, if you used a + different path). You will require the postgres development files. These are in + the ``libpq-dev`` package on Debian-derived distributions. + Set up database =============== @@ -26,29 +48,6 @@ encoding use, e.g.:: This would create an appropriate database named ``synapse`` owned by the ``synapse_user`` user (which must already exist). -Set up client in Debian/Ubuntu -=========================== - -Postgres support depends on the postgres python connector ``psycopg2``. In the -virtual env:: - - sudo apt-get install libpq-dev - pip install psycopg2 - -Set up client in RHEL/CentOs 7 -============================== - -Make sure you have the appropriate version of postgres-devel installed. For a -postgres 9.4, use the postgres 9.4 packages from -[here](https://wiki.postgresql.org/wiki/YUM_Installation). - -As with Debian/Ubuntu, postgres support depends on the postgres python connector -``psycopg2``. In the virtual env:: - - sudo yum install postgresql-devel libpqxx-devel.x86_64 - export PATH=/usr/pgsql-9.4/bin/:$PATH - pip install psycopg2 - Tuning Postgres =============== diff --git a/docs/reverse_proxy.rst b/docs/reverse_proxy.rst index cc81ceb84..7619b1097 100644 --- a/docs/reverse_proxy.rst +++ b/docs/reverse_proxy.rst @@ -69,6 +69,7 @@ Let's assume that we expect clients to connect to our server at SSLEngine on ServerName matrix.example.com; + AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix @@ -77,6 +78,7 @@ Let's assume that we expect clients to connect to our server at SSLEngine on ServerName example.com; + AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ab02e8f20..493ea9ee9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -69,6 +69,30 @@ pid_file: DATADIR/homeserver.pid # #use_presence: false +# Whether to require authentication to retrieve profile data (avatars, +# display names) of other users through the client API. Defaults to +# 'false'. Note that profile data is also available via the federation +# API, so this setting is of limited value if federation is enabled on +# the server. +# +#require_auth_for_profile_requests: true + +# If set to 'true', requires authentication to access the server's +# public rooms directory through the client API, and forbids any other +# homeserver to fetch it via federation. Defaults to 'false'. +# +#restrict_public_rooms_to_local_users: true + +# The default room version for newly created rooms. +# +# Known room versions are listed here: +# https://matrix.org/docs/spec/#complete-list-of-room-versions +# +# For example, for room version 1, default_room_version should be set +# to "1". +# +#default_room_version: "1" + # The GC threshold parameters to pass to `gc.set_threshold`, if defined # #gc_thresholds: [700, 10, 10] @@ -101,6 +125,24 @@ pid_file: DATADIR/homeserver.pid # - nyc.example.com # - syd.example.com +# Prevent federation requests from being sent to the following +# blacklist IP address CIDR ranges. If this option is not specified, or +# specified with an empty list, no ip range blacklist will be enforced. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -136,8 +178,8 @@ pid_file: DATADIR/homeserver.pid # # Valid resource names are: # -# client: the client-server API (/_matrix/client). Also implies 'media' and -# 'static'. +# client: the client-server API (/_matrix/client), and the synapse admin +# API (/_synapse/admin). Also implies 'media' and 'static'. # # consent: user consent forms (/_matrix/consent). See # docs/consent_tracking.md. @@ -239,6 +281,17 @@ listeners: # Used by phonehome stats to group together related servers. #server_context: context +# Whether to require a user to be in the room to add an alias to it. +# Defaults to 'true'. +# +#require_membership_for_aliases: false + +# Whether to allow per-room membership profiles through the send of membership +# events with profile information that differ from the target's global profile. +# Defaults to 'true'. +# +#allow_per_room_profiles: false + ## TLS ## @@ -260,6 +313,40 @@ listeners: # #tls_private_key_path: "CONFDIR/SERVERNAME.tls.key" +# Whether to verify TLS certificates when sending federation traffic. +# +# This currently defaults to `false`, however this will change in +# Synapse 1.0 when valid federation certificates will be required. +# +#federation_verify_certificates: true + +# Skip federation certificate verification on the following whitelist +# of domains. +# +# This setting should only be used in very specific cases, such as +# federation over Tor hidden services and similar. For private networks +# of homeservers, you likely want to use a private CA instead. +# +# Only effective if federation_verify_certicates is `true`. +# +#federation_certificate_verification_whitelist: +# - lon.example.com +# - *.domain.com +# - *.onion + +# List of custom certificate authorities for federation traffic. +# +# This setting should only normally be used within a private network of +# homeservers. +# +# Note that this list will replace those that are provided by your +# operating environment. Certificates must be in PEM format. +# +#federation_custom_ca_list: +# - myCA1.pem +# - myCA2.pem +# - myCA3.pem + # ACME support: This will configure Synapse to request a valid TLS certificate # for your configured `server_name` via Let's Encrypt. # @@ -375,21 +462,15 @@ log_config: "CONFDIR/SERVERNAME.log.config" ## Ratelimiting ## -# Number of messages a client can send per second -# -#rc_messages_per_second: 0.2 - -# Number of message a client can send before being throttled -# -#rc_message_burst_count: 10.0 - -# Ratelimiting settings for registration and login. +# Ratelimiting settings for client actions (registration, login, messaging). # # Each ratelimiting configuration is made of two parameters: # - per_second: number of requests a client can send per second. # - burst_count: number of requests a client can send before being throttled. # # Synapse currently uses the following configurations: +# - one for messages that ratelimits sending based on the account the client +# is using # - one for registration that ratelimits registration requests based on the # client's IP address. # - one for login that ratelimits login requests based on the client's IP @@ -402,6 +483,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # # The defaults are as shown below. # +#rc_message: +# per_second: 0.2 +# burst_count: 10 +# #rc_registration: # per_second: 0.17 # burst_count: 3 @@ -417,29 +502,28 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_second: 0.17 # burst_count: 3 -# The federation window size in milliseconds -# -#federation_rc_window_size: 1000 -# The number of federation requests from a single server in a window -# before the server will delay processing the request. +# Ratelimiting settings for incoming federation # -#federation_rc_sleep_limit: 10 - -# The duration in milliseconds to delay processing events from -# remote servers by if they go over the sleep limit. +# The rc_federation configuration is made up of the following settings: +# - window_size: window size in milliseconds +# - sleep_limit: number of federation requests from a single server in +# a window before the server will delay processing the request. +# - sleep_delay: duration in milliseconds to delay processing events +# from remote servers by if they go over the sleep limit. +# - reject_limit: maximum number of concurrent federation requests +# allowed from a single server +# - concurrent: number of federation requests to concurrently process +# from a single server # -#federation_rc_sleep_delay: 500 - -# The maximum number of concurrent federation requests allowed -# from a single server +# The defaults are as shown below. # -#federation_rc_reject_limit: 50 - -# The number of federation requests to concurrently process from a -# single server -# -#federation_rc_concurrent: 3 +#rc_federation: +# window_size: 1000 +# sleep_limit: 10 +# sleep_delay: 500 +# reject_limit: 50 +# concurrent: 3 # Target outgoing federation transaction frequency for sending read-receipts, # per-room. @@ -509,11 +593,12 @@ uploads_path: "DATADIR/uploads" # height: 600 # method: scale -# Is the preview URL API enabled? If enabled, you *must* specify -# an explicit url_preview_ip_range_blacklist of IPs that the spider is -# denied from accessing. +# Is the preview URL API enabled? # -#url_preview_enabled: false +# 'false' by default: uncomment the following to enable it (and specify a +# url_preview_ip_range_blacklist blacklist). +# +#url_preview_enabled: true # List of IP address CIDR ranges that the URL preview spider is denied # from accessing. There are no defaults: you must explicitly @@ -523,6 +608,12 @@ uploads_path: "DATADIR/uploads" # synapse to issue arbitrary GET requests to your internal services, # causing serious security issues. # +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +# This must be specified if url_preview_enabled is set. It is recommended that +# you uncomment the following list as a starting point. +# #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -533,7 +624,7 @@ uploads_path: "DATADIR/uploads" # - '::1/128' # - 'fe80::/64' # - 'fc00::/7' -# + # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. # This is useful for specifying exceptions to wide-ranging blacklisted @@ -666,6 +757,16 @@ uploads_path: "DATADIR/uploads" # link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter # from the ``email`` section. # +# Once this feature is enabled, Synapse will look for registered users without an +# expiration date at startup and will add one to every account it found using the +# current settings at that time. +# This means that, if a validity period is set, and Synapse is restarted (it will +# then derive an expiration date from the current validity period), and some time +# after that the validity period changes and Synapse is restarted, the users' +# expiration dates won't be updated unless their account is manually renewed. This +# date will be randomly selected within a range [now + period - d ; now + period], +# where d is equal to 10% of the validity period. +# #account_validity: # enabled: True # period: 6w @@ -1004,9 +1105,9 @@ password_config: # # 'search_all_users' defines whether to search all users visible to your HS # when searching the user directory, rather than limiting to users visible -# in public rooms. Defaults to false. If you set it True, you'll have to run -# UPDATE user_directory_stream_pos SET stream_id = NULL; -# on your database to tell it to rebuild the user_directory search indexes. +# in public rooms. Defaults to false. If you set it True, you'll have to +# rebuild the user_directory search indexes, see +# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # #user_directory: # enabled: true @@ -1064,6 +1165,22 @@ password_config: # + +# Local statistics collection. Used in populating the room directory. +# +# 'bucket_size' controls how large each statistics timeslice is. It can +# be defined in a human readable short form -- e.g. "1d", "1y". +# +# 'retention' controls how long historical statistics will be kept for. +# It can be defined in a human readable short form -- e.g. "1d", "1y". +# +# +#stats: +# enabled: true +# bucket_size: 1d +# retention: 1y + + # Server Notices room configuration # # Uncomment this section to enable a room which can be used to send notices diff --git a/docs/server_notices.md b/docs/server_notices.md index 58f877631..950a6608e 100644 --- a/docs/server_notices.md +++ b/docs/server_notices.md @@ -1,5 +1,4 @@ -Server Notices -============== +# Server Notices 'Server Notices' are a new feature introduced in Synapse 0.30. They provide a channel whereby server administrators can send messages to users on the server. @@ -11,8 +10,7 @@ they may also find a use for features such as "Message of the day". This is a feature specific to Synapse, but it uses standard Matrix communication mechanisms, so should work with any Matrix client. -User experience ---------------- +## User experience When the user is first sent a server notice, they will get an invitation to a room (typically called 'Server Notices', though this is configurable in @@ -29,8 +27,7 @@ levels. Having joined the room, the user can leave the room if they want. Subsequent server notices will then cause a new room to be created. -Synapse configuration ---------------------- +## Synapse configuration Server notices come from a specific user id on the server. Server administrators are free to choose the user id - something like `server` is @@ -58,17 +55,7 @@ room which will be created. `system_mxid_display_name` and `system_mxid_avatar_url` can be used to set the displayname and avatar of the Server Notices user. -Sending notices ---------------- +## Sending notices -As of the current version of synapse, there is no convenient interface for -sending notices (other than the automated ones sent as part of consent -tracking). - -In the meantime, it is possible to test this feature using the manhole. Having -gone into the manhole as described in [manhole.md](manhole.md), a notice can be -sent with something like: - -``` ->>> hs.get_server_notices_manager().send_notice('@user:server.com', {'msgtype':'m.text', 'body':'foo'}) -``` +To send server notices to users you can use the +[admin_api](admin_api/server_notices.md). diff --git a/docs/user_directory.md b/docs/user_directory.md index 4c8ee44f3..e64aa453c 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,11 +7,7 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -quickest solution to fix it is: - -``` -UPDATE user_directory_stream_pos SET stream_id = NULL; -``` - -and restart the synapse, which should then start a background task to +solution to fix it is to execute the SQL here +https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/delta/53/user_dir_populate.sql +and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 6b9be9906..93305ee9b 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -24,6 +24,7 @@ DISTS = ( "ubuntu:xenial", "ubuntu:bionic", "ubuntu:cosmic", + "ubuntu:disco", ) DESC = '''\ diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index da027be26..62e5a0747 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -20,9 +20,7 @@ class CallVisitor(ast.NodeVisitor): else: return - if name == "client_path_patterns": - PATTERNS_V1.append(node.args[0].s) - elif name == "client_v2_patterns": + if name == "client_patterns": PATTERNS_V2.append(node.args[0].s) diff --git a/synapse/__init__.py b/synapse/__init__.py index 6bb5a8b24..d0e8d7c21 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -27,4 +27,4 @@ try: except ImportError: pass -__version__ = "0.99.3" +__version__ = "0.99.5.2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 960e66dbd..0c6c93a87 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -556,7 +556,7 @@ class Auth(object): """ Check if the given user is a local server admin. Args: - user (str): mxid of user to check + user (UserID): user to check Returns: bool: True if the user is an admin diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 0860b7590..ee129c868 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -20,6 +20,12 @@ # the "depth" field on events is limited to 2**63 - 1 MAX_DEPTH = 2**63 - 1 +# the maximum length for a room alias is 255 characters +MAX_ALIAS_LENGTH = 255 + +# the maximum length for a user id is 255 characters +MAX_USERID_LENGTH = 255 + class Membership(object): @@ -73,6 +79,7 @@ class EventTypes(object): RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" + Encryption = "m.room.encryption" RoomAvatar = "m.room.avatar" RoomEncryption = "m.room.encryption" GuestAccess = "m.room.guest_access" @@ -113,3 +120,11 @@ class UserTypes(object): """ SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) + + +class RelationTypes(object): + """The types of relations known to this server. + """ + ANNOTATION = "m.annotation" + REPLACE = "m.replace" + REFERENCE = "m.reference" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ff89259de..e91697049 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -328,9 +328,23 @@ class RoomKeysVersionError(SynapseError): self.current_version = current_version -class IncompatibleRoomVersionError(SynapseError): - """A server is trying to join a room whose version it does not support.""" +class UnsupportedRoomVersionError(SynapseError): + """The client's request to create a room used a room version that the server does + not support.""" + def __init__(self): + super(UnsupportedRoomVersionError, self).__init__( + code=400, + msg="Homeserver does not support this room version", + errcode=Codes.UNSUPPORTED_ROOM_VERSION, + ) + +class IncompatibleRoomVersionError(SynapseError): + """A server is trying to join a room whose version it does not support. + + Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join + failing. + """ def __init__(self, room_version): super(IncompatibleRoomVersionError, self).__init__( code=400, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index e77abe104..4085bd10b 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,13 +19,15 @@ class EventFormatVersions(object): """This is an internal enum for tracking the version of the event format, independently from the room version. """ - V1 = 1 # $id:server format - V2 = 2 # MSC1659-style $hash format: introduced for room v3 + V1 = 1 # $id:server event id format + V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 + V3 = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { EventFormatVersions.V1, EventFormatVersions.V2, + EventFormatVersions.V3, } @@ -75,10 +77,12 @@ class RoomVersions(object): EventFormatVersions.V2, StateResolutionVersions.V2, ) - - -# the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = RoomVersions.V1 + V4 = RoomVersion( + "4", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + ) KNOWN_ROOM_VERSIONS = { @@ -87,5 +91,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V2, RoomVersions.V3, RoomVersions.STATE_V2_TEST, + RoomVersions.V4, ) } # type: dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index cb71d8087..e16c386a1 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -22,11 +22,11 @@ from six.moves.urllib.parse import urlencode from synapse.config import ConfigError -CLIENT_PREFIX = "/_matrix/client/api/v1" -CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" +CLIENT_API_PREFIX = "/_matrix/client" FEDERATION_PREFIX = "/_matrix/federation" FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" +FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index d4c6c4c8e..8cc990399 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,13 +22,14 @@ import traceback import psutil from daemonize import Daemonize -from twisted.internet import error, reactor +from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error from synapse.crypto import context_factory from synapse.util import PreserveLoggingContext +from synapse.util.async_helpers import Linearizer from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string @@ -99,6 +100,8 @@ def start_reactor( logger (logging.Logger): logger instance to pass to Daemonize """ + install_dns_limiter(reactor) + def run(): # make sure that we run the reactor with the sentinel log context, # otherwise other PreserveLoggingContext instances will get confused @@ -312,3 +315,87 @@ def setup_sentry(hs): name = hs.config.worker_name if hs.config.worker_name else "master" scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) + + +def install_dns_limiter(reactor, max_dns_requests_in_flight=100): + """Replaces the resolver with one that limits the number of in flight DNS + requests. + + This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we + can run out of file descriptors and infinite loop if we attempt to do too + many DNS queries at once + """ + new_resolver = _LimitedHostnameResolver( + reactor.nameResolver, max_dns_requests_in_flight, + ) + + reactor.installNameResolver(new_resolver) + + +class _LimitedHostnameResolver(object): + """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. + """ + + def __init__(self, resolver, max_dns_requests_in_flight): + self._resolver = resolver + self._limiter = Linearizer( + name="dns_client_limiter", max_count=max_dns_requests_in_flight, + ) + + def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, + addressTypes=None, transportSemantics='TCP'): + # We need this function to return `resolutionReceiver` so we do all the + # actual logic involving deferreds in a separate function. + + # even though this is happening within the depths of twisted, we need to drop + # our logcontext before starting _resolve, otherwise: (a) _resolve will drop + # the logcontext if it returns an incomplete deferred; (b) _resolve will + # call the resolutionReceiver *with* a logcontext, which it won't be expecting. + with PreserveLoggingContext(): + self._resolve( + resolutionReceiver, + hostName, + portNumber, + addressTypes, + transportSemantics, + ) + + return resolutionReceiver + + @defer.inlineCallbacks + def _resolve(self, resolutionReceiver, hostName, portNumber=0, + addressTypes=None, transportSemantics='TCP'): + + with (yield self._limiter.queue(())): + # resolveHostName doesn't return a Deferred, so we need to hook into + # the receiver interface to get told when resolution has finished. + + deferred = defer.Deferred() + receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred) + + self._resolver.resolveHostName( + receiver, hostName, portNumber, + addressTypes, transportSemantics, + ) + + yield deferred + + +class _DeferredResolutionReceiver(object): + """Wraps a IResolutionReceiver and simply resolves the given deferred when + resolution is complete + """ + + def __init__(self, receiver, deferred): + self._receiver = receiver + self._deferred = deferred + + def resolutionBegan(self, resolutionInProgress): + self._receiver.resolutionBegan(resolutionInProgress) + + def addressResolved(self, address): + self._receiver.addressResolved(address) + + def resolutionComplete(self): + self._deferred.callback(()) + self._receiver.resolutionComplete() diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 864f1eac4..a16e037f3 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -38,6 +38,7 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore @@ -81,6 +82,7 @@ class ClientReaderSlavedStore( SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedTransactionStore, + SlavedProfileStore, SlavedClientIpStore, BaseSlavedStore, ): diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 8479fee73..6504da527 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns -from synapse.rest.client.v2_alpha._base import client_v2_patterns +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree @@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.frontend_proxy") -class PresenceStatusStubServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/presence/(?P[^/]*)/status") +class PresenceStatusStubServlet(RestServlet): + PATTERNS = client_patterns("/presence/(?P[^/]*)/status") def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__(hs) + super(PresenceStatusStubServlet, self).__init__() self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() self.main_uri = hs.config.worker_main_http_uri @@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet): class KeyUploadServlet(RestServlet): - PATTERNS = client_v2_patterns("/keys/upload(/(?P[^/]+))?$") + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") def __init__(self, hs): """ diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 79be977ea..1045d2894 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource +from synapse.rest.admin import AdminRestResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource @@ -180,6 +181,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), + "/_synapse/admin": AdminRestResource(self), }) if self.get_config().saml2_enabled: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 342a6ce5f..8400471f4 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,12 +31,50 @@ logger = logging.getLogger(__name__) class EmailConfig(Config): def read_config(self, config): + # TODO: We should separate better the email configuration from the notification + # and account validity config. + self.email_enable_notifs = False email_config = config.get("email", {}) - self.email_enable_notifs = email_config.get("enable_notifs", False) - if self.email_enable_notifs: + self.email_smtp_host = email_config.get("smtp_host", None) + self.email_smtp_port = email_config.get("smtp_port", None) + self.email_smtp_user = email_config.get("smtp_user", None) + self.email_smtp_pass = email_config.get("smtp_pass", None) + self.require_transport_security = email_config.get( + "require_transport_security", False + ) + if "app_name" in email_config: + self.email_app_name = email_config["app_name"] + else: + self.email_app_name = "Matrix" + + self.email_notif_from = email_config.get("notif_from", None) + if self.email_notif_from is not None: + # make sure it's valid + parsed = email.utils.parseaddr(self.email_notif_from) + if parsed[1] == '': + raise RuntimeError("Invalid notif_from address") + + template_dir = email_config.get("template_dir") + # we need an absolute path, because we change directory after starting (and + # we don't yet know what auxilliary templates like mail.css we will need). + # (Note that loading as package_resources with jinja.PackageLoader doesn't + # work for the same reason.) + if not template_dir: + template_dir = pkg_resources.resource_filename( + 'synapse', 'res/templates' + ) + + self.email_template_dir = os.path.abspath(template_dir) + + self.email_enable_notifs = email_config.get("enable_notifs", False) + account_validity_renewal_enabled = config.get( + "account_validity", {}, + ).get("renew_at") + + if self.email_enable_notifs or account_validity_renewal_enabled: # make sure we can import the required deps import jinja2 import bleach @@ -42,6 +82,7 @@ class EmailConfig(Config): jinja2 bleach + if self.email_enable_notifs: required = [ "smtp_host", "smtp_port", @@ -66,34 +107,13 @@ class EmailConfig(Config): "email.enable_notifs is True but no public_baseurl is set" ) - self.email_smtp_host = email_config["smtp_host"] - self.email_smtp_port = email_config["smtp_port"] - self.email_notif_from = email_config["notif_from"] self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_text = email_config["notif_template_text"] - self.email_expiry_template_html = email_config.get( - "expiry_template_html", "notice_expiry.html", - ) - self.email_expiry_template_text = email_config.get( - "expiry_template_text", "notice_expiry.txt", - ) - - template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxilliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename( - 'synapse', 'res/templates' - ) - template_dir = os.path.abspath(template_dir) for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(template_dir, f) + p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): raise ConfigError("Unable to find email template file %s" % (p, )) - self.email_template_dir = template_dir self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -101,29 +121,24 @@ class EmailConfig(Config): self.email_riot_base_url = email_config.get( "riot_base_url", None ) - self.email_smtp_user = email_config.get( - "smtp_user", None - ) - self.email_smtp_pass = email_config.get( - "smtp_pass", None - ) - self.require_transport_security = email_config.get( - "require_transport_security", False - ) - if "app_name" in email_config: - self.email_app_name = email_config["app_name"] - else: - self.email_app_name = "Matrix" - - # make sure it's valid - parsed = email.utils.parseaddr(self.email_notif_from) - if parsed[1] == '': - raise RuntimeError("Invalid notif_from address") else: self.email_enable_notifs = False # Not much point setting defaults for the rest: it would be an # error for them to be used. + if account_validity_renewal_enabled: + self.email_expiry_template_html = email_config.get( + "expiry_template_html", "notice_expiry.html", + ) + self.email_expiry_template_text = email_config.get( + "expiry_template_text", "notice_expiry.txt", + ) + + for f in self.email_expiry_template_text, self.email_expiry_template_html: + p = os.path.join(self.email_template_dir, f) + if not os.path.isfile(p): + raise ConfigError("Unable to find email template file %s" % (p, )) + def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable sending emails for notification events or expiry notices diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 727fdc54d..5c4fc8ff2 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from .api import ApiConfig from .appservice import AppServiceConfig from .captcha import CaptchaConfig @@ -36,20 +37,41 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .stats import StatsConfig from .tls import TlsConfig from .user_directory import UserDirectoryConfig from .voip import VoipConfig from .workers import WorkerConfig -class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig, - RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, - AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, - ConsentConfig, - ServerNoticesConfig, RoomDirectoryConfig, - ): +class HomeServerConfig( + ServerConfig, + TlsConfig, + DatabaseConfig, + LoggingConfig, + RatelimitConfig, + ContentRepositoryConfig, + CaptchaConfig, + VoipConfig, + RegistrationConfig, + MetricsConfig, + ApiConfig, + AppServiceConfig, + KeyConfig, + SAML2Config, + CasConfig, + JWTConfig, + PasswordConfig, + EmailConfig, + WorkerConfig, + PasswordAuthProviderConfig, + PushConfig, + SpamCheckerConfig, + GroupsConfig, + UserDirectoryConfig, + ConsentConfig, + StatsConfig, + ServerNoticesConfig, + RoomDirectoryConfig, +): pass diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5a68399e6..5a9adac48 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -16,16 +16,56 @@ from ._base import Config class RateLimitConfig(object): - def __init__(self, config): - self.per_second = config.get("per_second", 0.17) - self.burst_count = config.get("burst_count", 3.0) + def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + self.per_second = config.get("per_second", defaults["per_second"]) + self.burst_count = config.get("burst_count", defaults["burst_count"]) + + +class FederationRateLimitConfig(object): + _items_and_default = { + "window_size": 10000, + "sleep_limit": 10, + "sleep_delay": 500, + "reject_limit": 50, + "concurrent": 3, + } + + def __init__(self, **kwargs): + for i in self._items_and_default.keys(): + setattr(self, i, kwargs.get(i) or self._items_and_default[i]) class RatelimitConfig(Config): - def read_config(self, config): - self.rc_messages_per_second = config.get("rc_messages_per_second", 0.2) - self.rc_message_burst_count = config.get("rc_message_burst_count", 10.0) + + # Load the new-style messages config if it exists. Otherwise fall back + # to the old method. + if "rc_message" in config: + self.rc_message = RateLimitConfig( + config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} + ) + else: + self.rc_message = RateLimitConfig( + { + "per_second": config.get("rc_messages_per_second", 0.2), + "burst_count": config.get("rc_message_burst_count", 10.0), + } + ) + + # Load the new-style federation config, if it exists. Otherwise, fall + # back to the old method. + if "federation_rc" in config: + self.rc_federation = FederationRateLimitConfig(**config["rc_federation"]) + else: + self.rc_federation = FederationRateLimitConfig( + **{ + "window_size": config.get("federation_rc_window_size"), + "sleep_limit": config.get("federation_rc_sleep_limit"), + "sleep_delay": config.get("federation_rc_sleep_delay"), + "reject_limit": config.get("federation_rc_reject_limit"), + "concurrent": config.get("federation_rc_concurrent"), + } + ) self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) @@ -33,38 +73,26 @@ class RatelimitConfig(Config): self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) self.rc_login_failed_attempts = RateLimitConfig( - rc_login_config.get("failed_attempts", {}), + rc_login_config.get("failed_attempts", {}) ) - self.federation_rc_window_size = config.get("federation_rc_window_size", 1000) - self.federation_rc_sleep_limit = config.get("federation_rc_sleep_limit", 10) - self.federation_rc_sleep_delay = config.get("federation_rc_sleep_delay", 500) - self.federation_rc_reject_limit = config.get("federation_rc_reject_limit", 50) - self.federation_rc_concurrent = config.get("federation_rc_concurrent", 3) - self.federation_rr_transactions_per_room_per_second = config.get( - "federation_rr_transactions_per_room_per_second", 50, + "federation_rr_transactions_per_room_per_second", 50 ) def default_config(self, **kwargs): return """\ ## Ratelimiting ## - # Number of messages a client can send per second - # - #rc_messages_per_second: 0.2 - - # Number of message a client can send before being throttled - # - #rc_message_burst_count: 10.0 - - # Ratelimiting settings for registration and login. + # Ratelimiting settings for client actions (registration, login, messaging). # # Each ratelimiting configuration is made of two parameters: # - per_second: number of requests a client can send per second. # - burst_count: number of requests a client can send before being throttled. # # Synapse currently uses the following configurations: + # - one for messages that ratelimits sending based on the account the client + # is using # - one for registration that ratelimits registration requests based on the # client's IP address. # - one for login that ratelimits login requests based on the client's IP @@ -77,6 +105,10 @@ class RatelimitConfig(Config): # # The defaults are as shown below. # + #rc_message: + # per_second: 0.2 + # burst_count: 10 + # #rc_registration: # per_second: 0.17 # burst_count: 3 @@ -92,29 +124,28 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 - # The federation window size in milliseconds - # - #federation_rc_window_size: 1000 - # The number of federation requests from a single server in a window - # before the server will delay processing the request. + # Ratelimiting settings for incoming federation # - #federation_rc_sleep_limit: 10 - - # The duration in milliseconds to delay processing events from - # remote servers by if they go over the sleep limit. + # The rc_federation configuration is made up of the following settings: + # - window_size: window size in milliseconds + # - sleep_limit: number of federation requests from a single server in + # a window before the server will delay processing the request. + # - sleep_delay: duration in milliseconds to delay processing events + # from remote servers by if they go over the sleep limit. + # - reject_limit: maximum number of concurrent federation requests + # allowed from a single server + # - concurrent: number of federation requests to concurrently process + # from a single server # - #federation_rc_sleep_delay: 500 - - # The maximum number of concurrent federation requests allowed - # from a single server + # The defaults are as shown below. # - #federation_rc_reject_limit: 50 - - # The number of federation requests to concurrently process from a - # single server - # - #federation_rc_concurrent: 3 + #rc_federation: + # window_size: 1000 + # sleep_limit: 10 + # sleep_delay: 500 + # reject_limit: 50 + # concurrent: 3 # Target outgoing federation transaction frequency for sending read-receipts, # per-room. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1309bce3e..aad340081 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -39,6 +39,8 @@ class AccountValidityConfig(Config): else: self.renew_email_subject = "Renew your %(app)s account" + self.startup_job_max_delta = self.period * 10. / 100. + if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: raise ConfigError("Can't send renewal emails without 'public_baseurl'") @@ -123,6 +125,16 @@ class RegistrationConfig(Config): # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter # from the ``email`` section. # + # Once this feature is enabled, Synapse will look for registered users without an + # expiration date at startup and will add one to every account it found using the + # current settings at that time. + # This means that, if a validity period is set, and Synapse is restarted (it will + # then derive an expiration date from the current validity period), and some time + # after that the validity period changes and Synapse is restarted, the users' + # expiration dates won't be updated unless their account is manually renewed. This + # date will be randomly selected within a range [now + period - d ; now + period], + # where d is equal to 10%% of the validity period. + # #account_validity: # enabled: True # period: 6w diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 3f34ad9b2..fbfcecc24 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -186,17 +186,21 @@ class ContentRepositoryConfig(Config): except ImportError: raise ConfigError(MISSING_NETADDR) - if "url_preview_ip_range_blacklist" in config: - self.url_preview_ip_range_blacklist = IPSet( - config["url_preview_ip_range_blacklist"] - ) - else: + if "url_preview_ip_range_blacklist" not in config: raise ConfigError( "For security, you must specify an explicit target IP address " "blacklist in url_preview_ip_range_blacklist for url previewing " "to work" ) + self.url_preview_ip_range_blacklist = IPSet( + config["url_preview_ip_range_blacklist"] + ) + + # we always blacklist '0.0.0.0' and '::', which are supposed to be + # unroutable addresses. + self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::']) + self.url_preview_ip_range_whitelist = IPSet( config.get("url_preview_ip_range_whitelist", ()) ) @@ -260,11 +264,12 @@ class ContentRepositoryConfig(Config): #thumbnail_sizes: %(formatted_thumbnail_sizes)s - # Is the preview URL API enabled? If enabled, you *must* specify - # an explicit url_preview_ip_range_blacklist of IPs that the spider is - # denied from accessing. + # Is the preview URL API enabled? # - #url_preview_enabled: false + # 'false' by default: uncomment the following to enable it (and specify a + # url_preview_ip_range_blacklist blacklist). + # + #url_preview_enabled: true # List of IP address CIDR ranges that the URL preview spider is denied # from accessing. There are no defaults: you must explicitly @@ -274,6 +279,12 @@ class ContentRepositoryConfig(Config): # synapse to issue arbitrary GET requests to your internal services, # causing serious security issues. # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + # This must be specified if url_preview_enabled is set. It is recommended that + # you uncomment the following list as a starting point. + # #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -284,7 +295,7 @@ class ContentRepositoryConfig(Config): # - '::1/128' # - 'fe80::/64' # - 'fc00::/7' - # + # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. # This is useful for specifying exceptions to wide-ranging blacklisted diff --git a/synapse/config/server.py b/synapse/config/server.py index c5e5679d5..e763e19e1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,9 @@ import logging import os.path +from netaddr import IPSet + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name from synapse.python_dependencies import DependencyException, check_requirements @@ -32,6 +36,8 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0'] +DEFAULT_ROOM_VERSION = "1" + class ServerConfig(Config): @@ -72,6 +78,35 @@ class ServerConfig(Config): # master, potentially causing inconsistency. self.enable_media_repo = config.get("enable_media_repo", True) + # Whether to require authentication to retrieve profile data (avatars, + # display names) of other users through the client API. + self.require_auth_for_profile_requests = config.get( + "require_auth_for_profile_requests", False, + ) + + # If set to 'True', requires authentication to access the server's + # public rooms directory through the client API, and forbids any other + # homeserver to fetch it via federation. + self.restrict_public_rooms_to_local_users = config.get( + "restrict_public_rooms_to_local_users", False, + ) + + default_room_version = config.get( + "default_room_version", DEFAULT_ROOM_VERSION, + ) + + # Ensure room version is a str + default_room_version = str(default_room_version) + + if default_room_version not in KNOWN_ROOM_VERSIONS: + raise ConfigError( + "Unknown default_room_version: %s, known room versions: %s" % + (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) + ) + + # Get the actual room version object rather than just the identifier + self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version] + # whether to enable search. If disabled, new entries will not be inserted # into the search tables and they will not be indexed. Users will receive # errors when attempting to search for messages. @@ -85,6 +120,11 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Whether to enable experimental MSC1849 (aka relations) support + self.experimental_msc1849_support_enabled = config.get( + "experimental_msc1849_support_enabled", False, + ) + # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 @@ -114,14 +154,34 @@ class ServerConfig(Config): # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( - "federation_domain_whitelist", None + "federation_domain_whitelist", None, ) - # turn the whitelist into a hash for speed of lookup + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup self.federation_domain_whitelist = {} + for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [], + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in " + "federation_ip_range_blacklist: %s" % e + ) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -132,6 +192,16 @@ class ServerConfig(Config): # sending out any replication updates. self.replication_torture_level = config.get("replication_torture_level") + # Whether to require a user to be in the room to add an alias to it. + # Defaults to True. + self.require_membership_for_aliases = config.get( + "require_membership_for_aliases", True, + ) + + # Whether to allow per-room membership profiles through the send of membership + # events with profile information that differ from the target's global profile. + self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + self.listeners = [] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -259,6 +329,10 @@ class ServerConfig(Config): unsecure_port = 8008 pid_file = os.path.join(data_dir_path, "homeserver.pid") + + # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the + # default config string + default_room_version = DEFAULT_ROOM_VERSION return """\ ## Server ## @@ -319,6 +393,30 @@ class ServerConfig(Config): # #use_presence: false + # Whether to require authentication to retrieve profile data (avatars, + # display names) of other users through the client API. Defaults to + # 'false'. Note that profile data is also available via the federation + # API, so this setting is of limited value if federation is enabled on + # the server. + # + #require_auth_for_profile_requests: true + + # If set to 'true', requires authentication to access the server's + # public rooms directory through the client API, and forbids any other + # homeserver to fetch it via federation. Defaults to 'false'. + # + #restrict_public_rooms_to_local_users: true + + # The default room version for newly created rooms. + # + # Known room versions are listed here: + # https://matrix.org/docs/spec/#complete-list-of-room-versions + # + # For example, for room version 1, default_room_version should be set + # to "1". + # + #default_room_version: "%(default_room_version)s" + # The GC threshold parameters to pass to `gc.set_threshold`, if defined # #gc_thresholds: [700, 10, 10] @@ -351,6 +449,24 @@ class ServerConfig(Config): # - nyc.example.com # - syd.example.com + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -386,8 +502,8 @@ class ServerConfig(Config): # # Valid resource names are: # - # client: the client-server API (/_matrix/client). Also implies 'media' and - # 'static'. + # client: the client-server API (/_matrix/client), and the synapse admin + # API (/_synapse/admin). Also implies 'media' and 'static'. # # consent: user consent forms (/_matrix/consent). See # docs/consent_tracking.md. @@ -488,6 +604,17 @@ class ServerConfig(Config): # Used by phonehome stats to group together related servers. #server_context: context + + # Whether to require a user to be in the room to add an alias to it. + # Defaults to 'true'. + # + #require_membership_for_aliases: false + + # Whether to allow per-room membership profiles through the send of membership + # events with profile information that differ from the target's global profile. + # Defaults to 'true'. + # + #allow_per_room_profiles: false """ % locals() def read_arguments(self, args): diff --git a/synapse/config/stats.py b/synapse/config/stats.py new file mode 100644 index 000000000..80fc1b9dd --- /dev/null +++ b/synapse/config/stats.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import sys + +from ._base import Config + + +class StatsConfig(Config): + """Stats Configuration + Configuration for the behaviour of synapse's stats engine + """ + + def read_config(self, config): + self.stats_enabled = True + self.stats_bucket_size = 86400 + self.stats_retention = sys.maxsize + stats_config = config.get("stats", None) + if stats_config: + self.stats_enabled = stats_config.get("enabled", self.stats_enabled) + self.stats_bucket_size = ( + self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000 + ) + self.stats_retention = ( + self.parse_duration( + stats_config.get("retention", "%ds" % (sys.maxsize,)) + ) + / 1000 + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Local statistics collection. Used in populating the room directory. + # + # 'bucket_size' controls how large each statistics timeslice is. It can + # be defined in a human readable short form -- e.g. "1d", "1y". + # + # 'retention' controls how long historical statistics will be kept for. + # It can be defined in a human readable short form -- e.g. "1d", "1y". + # + # + #stats: + # enabled: true + # bucket_size: 1d + # retention: 1y + """ diff --git a/synapse/config/tls.py b/synapse/config/tls.py index f0014902d..72dd5926f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -24,8 +24,10 @@ import six from unpaddedbase64 import encode_base64 from OpenSSL import crypto +from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError +from synapse.util import glob_to_regex logger = logging.getLogger(__name__) @@ -70,6 +72,53 @@ class TlsConfig(Config): self.tls_fingerprints = list(self._original_tls_fingerprints) + # Whether to verify certificates on outbound federation traffic + self.federation_verify_certificates = config.get( + "federation_verify_certificates", False, + ) + + # Whitelist of domains to not verify certificates for + fed_whitelist_entries = config.get( + "federation_certificate_verification_whitelist", [], + ) + + # Support globs (*) in whitelist values + self.federation_certificate_verification_whitelist = [] + for entry in fed_whitelist_entries: + # Convert globs to regex + entry_regex = glob_to_regex(entry) + self.federation_certificate_verification_whitelist.append(entry_regex) + + # List of custom certificate authorities for federation traffic validation + custom_ca_list = config.get( + "federation_custom_ca_list", None, + ) + + # Read in and parse custom CA certificates + self.federation_ca_trust_root = None + if custom_ca_list is not None: + if len(custom_ca_list) == 0: + # A trustroot cannot be generated without any CA certificates. + # Raise an error if this option has been specified without any + # corresponding certificates. + raise ConfigError("federation_custom_ca_list specified without " + "any certificate files") + + certs = [] + for ca_file in custom_ca_list: + logger.debug("Reading custom CA certificate file: %s", ca_file) + content = self.read_file(ca_file) + + # Parse the CA certificates + try: + cert_base = Certificate.loadPEM(content) + certs.append(cert_base) + except Exception as e: + raise ConfigError("Error parsing custom CA certificate file %s: %s" + % (ca_file, e)) + + self.federation_ca_trust_root = trustRootFromCertificates(certs) + # This config option applies to non-federation HTTP clients # (e.g. for talking to recaptcha, identity servers, and such) # It should never be used in production, and is intended for @@ -99,15 +148,15 @@ class TlsConfig(Config): try: with open(self.tls_certificate_file, 'rb') as f: cert_pem = f.read() - except Exception: - logger.exception("Failed to read existing certificate off disk!") - raise + except Exception as e: + raise ConfigError("Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e)) try: tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) - except Exception: - logger.exception("Failed to parse existing certificate off disk!") - raise + except Exception as e: + raise ConfigError("Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e)) if not allow_self_signed: if tls_certificate.get_subject() == tls_certificate.get_issuer(): @@ -192,6 +241,40 @@ class TlsConfig(Config): # #tls_private_key_path: "%(tls_private_key_path)s" + # Whether to verify TLS certificates when sending federation traffic. + # + # This currently defaults to `false`, however this will change in + # Synapse 1.0 when valid federation certificates will be required. + # + #federation_verify_certificates: true + + # Skip federation certificate verification on the following whitelist + # of domains. + # + # This setting should only be used in very specific cases, such as + # federation over Tor hidden services and similar. For private networks + # of homeservers, you likely want to use a private CA instead. + # + # Only effective if federation_verify_certicates is `true`. + # + #federation_certificate_verification_whitelist: + # - lon.example.com + # - *.domain.com + # - *.onion + + # List of custom certificate authorities for federation traffic. + # + # This setting should only normally be used within a private network of + # homeservers. + # + # Note that this list will replace those that are provided by your + # operating environment. Certificates must be in PEM format. + # + #federation_custom_ca_list: + # - myCA1.pem + # - myCA2.pem + # - myCA3.pem + # ACME support: This will configure Synapse to request a valid TLS certificate # for your configured `server_name` via Let's Encrypt. # diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 142754a7d..023997ccd 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -43,9 +43,9 @@ class UserDirectoryConfig(Config): # # 'search_all_users' defines whether to search all users visible to your HS # when searching the user directory, rather than limiting to users visible - # in public rooms. Defaults to false. If you set it True, you'll have to run - # UPDATE user_directory_stream_pos SET stream_id = NULL; - # on your database to tell it to rebuild the user_directory search indexes. + # in public rooms. Defaults to false. If you set it True, you'll have to + # rebuild the user_directory search indexes, see + # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # #user_directory: # enabled: true diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 49cbc7098..59ea087e6 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -18,10 +18,10 @@ import logging from zope.interface import implementer from OpenSSL import SSL, crypto -from twisted.internet._sslverify import _defaultCurveName +from twisted.internet._sslverify import ClientTLSOptions, _defaultCurveName from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator -from twisted.internet.ssl import CertificateOptions, ContextFactory +from twisted.internet.ssl import CertificateOptions, ContextFactory, platformTrust from twisted.python.failure import Failure logger = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def _tolerateErrors(wrapped): @implementer(IOpenSSLClientConnectionCreator) -class ClientTLSOptions(object): +class ClientTLSOptionsNoVerify(object): """ Client creator for TLS without certificate identity verification. This is a copy of twisted.internet._sslverify.ClientTLSOptions with the identity @@ -127,9 +127,30 @@ class ClientTLSOptionsFactory(object): to remote servers for federation.""" def __init__(self, config): - # We don't use config options yet - self._options = CertificateOptions(verify=False) + self._config = config + self._options_noverify = CertificateOptions() + + # Check if we're using a custom list of a CA certificates + trust_root = config.federation_ca_trust_root + if trust_root is None: + # Use CA root certs provided by OpenSSL + trust_root = platformTrust() + + self._options_verify = CertificateOptions(trustRoot=trust_root) def get_options(self, host): # Use _makeContext so that we get a fresh OpenSSL CTX each time. - return ClientTLSOptions(host, self._options._makeContext()) + + # Check if certificate verification has been enabled + should_verify = self._config.federation_verify_certificates + + # Check if we've disabled certificate verification for this host + if should_verify: + for regex in self._config.federation_certificate_verification_whitelist: + if regex.match(host): + should_verify = False + break + + if should_verify: + return ClientTLSOptions(host, self._options_verify._makeContext()) + return ClientTLSOptionsNoVerify(host, self._options_noverify._makeContext()) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 1dfa727fc..99a586655 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -31,7 +31,11 @@ logger = logging.getLogger(__name__) def check_event_content_hash(event, hash_algorithm=hashlib.sha256): """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) - logger.debug("Expecting hash: %s", encode_base64(expected_hash)) + logger.debug( + "Verifying content hash on %s (expecting: %s)", + event.event_id, + encode_base64(expected_hash), + ) # some malformed events lack a 'hashes'. Protect against it being missing # or a weird type by basically treating it the same as an unhashed event. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ed2e99443..e94e71bda 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -15,12 +15,13 @@ # limitations under the License. import logging -from collections import namedtuple +from collections import defaultdict +import six from six import raise_from from six.moves import urllib -import nacl.signing +import attr from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -43,7 +44,9 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext, unwrapFirstError +from synapse.util.async_helpers import yieldable_gather_results from synapse.util.logcontext import ( LoggingContext, PreserveLoggingContext, @@ -56,22 +59,36 @@ from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) -VerifyKeyRequest = namedtuple("VerifyRequest", ( - "server_name", "key_ids", "json_object", "deferred" -)) -""" -A request for a verify key to verify a JSON object. +@attr.s(slots=True, cmp=False) +class VerifyKeyRequest(object): + """ + A request for a verify key to verify a JSON object. -Attributes: - server_name(str): The name of the server to verify against. - key_ids(set(str)): The set of key_ids to that could be used to verify the - JSON object - json_object(dict): The JSON object to verify. - deferred(Deferred[str, str, nacl.signing.VerifyKey]): - A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched. The deferreds' callbacks are run with no - logcontext. -""" + Attributes: + server_name(str): The name of the server to verify against. + + key_ids(set[str]): The set of key_ids to that could be used to verify the + JSON object + + json_object(dict): The JSON object to verify. + + minimum_valid_until_ts (int): time at which we require the signing key to + be valid. (0 implies we don't care) + + key_ready (Deferred[str, str, nacl.signing.VerifyKey]): + A deferred (server_name, key_id, verify_key) tuple that resolves when + a verify key has been fetched. The deferreds' callbacks are run with no + logcontext. + + If we are unable to find a key which satisfies the request, the deferred + errbacks with an M_UNAUTHORIZED SynapseError. + """ + + server_name = attr.ib() + key_ids = attr.ib() + json_object = attr.ib() + minimum_valid_until_ts = attr.ib() + key_ready = attr.ib(default=attr.Factory(defer.Deferred)) class KeyLookupError(ValueError): @@ -79,13 +96,16 @@ class KeyLookupError(ValueError): class Keyring(object): - def __init__(self, hs): - self.store = hs.get_datastore() + def __init__(self, hs, key_fetchers=None): self.clock = hs.get_clock() - self.client = hs.get_http_client() - self.config = hs.get_config() - self.perspective_servers = self.config.perspectives - self.hs = hs + + if key_fetchers is None: + key_fetchers = ( + StoreKeyFetcher(hs), + PerspectivesKeyFetcher(hs), + ServerKeyFetcher(hs), + ) + self._key_fetchers = key_fetchers # map from server name to Deferred. Has an entry for each server with # an ongoing key download; the Deferred completes once the download @@ -94,11 +114,25 @@ class Keyring(object): # These are regular, logcontext-agnostic Deferreds. self.key_downloads = {} - def verify_json_for_server(self, server_name, json_object): + def verify_json_for_server(self, server_name, json_object, validity_time): + """Verify that a JSON object has been signed by a given server + + Args: + server_name (str): name of the server which must have signed this object + + json_object (dict): object to be checked + + validity_time (int): timestamp at which we require the signing key to + be valid. (0 implies we don't care) + + Returns: + Deferred[None]: completes if the the object was correctly signed, otherwise + errbacks with an error + """ + req = server_name, json_object, validity_time + return logcontext.make_deferred_yieldable( - self.verify_json_objects_for_server( - [(server_name, json_object)] - )[0] + self.verify_json_objects_for_server((req,))[0] ) def verify_json_objects_for_server(self, server_and_json): @@ -106,53 +140,71 @@ class Keyring(object): necessary. Args: - server_and_json (list): List of pairs of (server_name, json_object) + server_and_json (iterable[Tuple[str, dict, int]): + Iterable of triplets of (server_name, json_object, validity_time) + validity_time is a timestamp at which the signing key must be valid. Returns: - List: for each input pair, a deferred indicating success + List: for each input triplet, a deferred indicating success or failure to verify each json object's signature for the given server_name. The deferreds run their callbacks in the sentinel logcontext. """ + # a list of VerifyKeyRequests verify_requests = [] + handle = preserve_fn(_handle_key_deferred) - for server_name, json_object in server_and_json: + def process(server_name, json_object, validity_time): + """Process an entry in the request list + Given a (server_name, json_object, validity_time) triplet from the request + list, adds a key request to verify_requests, and returns a deferred which + will complete or fail (in the sentinel context) when verification completes. + """ key_ids = signature_ids(json_object, server_name) + if not key_ids: - logger.warn("Request from %s: no supported signature keys", - server_name) - deferred = defer.fail(SynapseError( - 400, - "Not signed with a supported algorithm", - Codes.UNAUTHORIZED, - )) - else: - deferred = defer.Deferred() + return defer.fail( + SynapseError( + 400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED + ) + ) - logger.debug("Verifying for %s with key_ids %s", - server_name, key_ids) - - verify_request = VerifyKeyRequest( - server_name, key_ids, json_object, deferred + logger.debug( + "Verifying for %s with key_ids %s, min_validity %i", + server_name, + key_ids, + validity_time, ) + # add the key request to the queue, but don't start it off yet. + verify_request = VerifyKeyRequest( + server_name, key_ids, json_object, validity_time + ) verify_requests.append(verify_request) - run_in_background(self._start_key_lookups, verify_requests) + # now run _handle_key_deferred, which will wait for the key request + # to complete and then do the verification. + # + # We want _handle_key_request to log to the right context, so we + # wrap it with preserve_fn (aka run_in_background) + return handle(verify_request) - # Pass those keys to handle_key_deferred so that the json object - # signatures can be verified - handle = preserve_fn(_handle_key_deferred) - return [ - handle(rq) for rq in verify_requests + results = [ + process(server_name, json_object, validity_time) + for server_name, json_object, validity_time in server_and_json ] + if verify_requests: + run_in_background(self._start_key_lookups, verify_requests) + + return results + @defer.inlineCallbacks def _start_key_lookups(self, verify_requests): """Sets off the key fetches for each verify request - Once each fetch completes, verify_request.deferred will be resolved. + Once each fetch completes, verify_request.key_ready will be resolved. Args: verify_requests (List[VerifyKeyRequest]): @@ -165,16 +217,12 @@ class Keyring(object): # any other lookups until we have finished. # The deferreds are called with no logcontext. server_to_deferred = { - rq.server_name: defer.Deferred() - for rq in verify_requests + rq.server_name: defer.Deferred() for rq in verify_requests } # We want to wait for any previous lookups to complete before # proceeding. - yield self.wait_for_previous_lookups( - [rq.server_name for rq in verify_requests], - server_to_deferred, - ) + yield self.wait_for_previous_lookups(server_to_deferred) # Actually start fetching keys. self._get_server_verify_keys(verify_requests) @@ -202,19 +250,16 @@ class Keyring(object): return res for verify_request in verify_requests: - verify_request.deferred.addBoth( - remove_deferreds, verify_request, - ) + verify_request.key_ready.addBoth(remove_deferreds, verify_request) except Exception: logger.exception("Error starting key lookups") @defer.inlineCallbacks - def wait_for_previous_lookups(self, server_names, server_to_deferred): + def wait_for_previous_lookups(self, server_to_deferred): """Waits for any previous key lookups for the given servers to finish. Args: - server_names (list): list of server_names we want to lookup - server_to_deferred (dict): server_name to deferred which gets + server_to_deferred (dict[str, Deferred]): server_name to deferred which gets resolved once we've finished looking up keys for that server. The Deferreds should be regular twisted ones which call their callbacks with no logcontext. @@ -227,14 +272,15 @@ class Keyring(object): while True: wait_on = [ (server_name, self.key_downloads[server_name]) - for server_name in server_names + for server_name in server_to_deferred.keys() if server_name in self.key_downloads ] if not wait_on: break logger.info( "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], loop_count, + [w[0] for w in wait_on], + loop_count, ) with PreserveLoggingContext(): yield defer.DeferredList((w[1] for w in wait_on)) @@ -257,7 +303,7 @@ class Keyring(object): def _get_server_verify_keys(self, verify_requests): """Tries to find at least one key for each verify request - For each verify_request, verify_request.deferred is called back with + For each verify_request, verify_request.key_ready is called back with params (server_name, key_id, VerifyKey) if a key is found, or errbacked with a SynapseError if none of the keys are found. @@ -265,300 +311,151 @@ class Keyring(object): verify_requests (list[VerifyKeyRequest]): list of verify requests """ - # These are functions that produce keys given a list of key ids - key_fetch_fns = ( - self.get_keys_from_store, # First try the local store - self.get_keys_from_perspectives, # Then try via perspectives - self.get_keys_from_server, # Then try directly + remaining_requests = set( + (rq for rq in verify_requests if not rq.key_ready.called) ) @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): - # dict[str, set(str)]: keys to fetch for each server - missing_keys = {} - for verify_request in verify_requests: - missing_keys.setdefault(verify_request.server_name, set()).update( - verify_request.key_ids - ) - - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) - - # We now need to figure out which verify requests we have keys - # for and which we don't - missing_keys = {} - requests_missing_keys = [] - for verify_request in verify_requests: - if verify_request.deferred.called: - # We've already called this deferred, which probably - # means that we've already found a key for it. - continue - - server_name = verify_request.server_name - - # see if any of the keys we got this time are sufficient to - # complete this VerifyKeyRequest. - result_keys = results.get(server_name, {}) - for key_id in verify_request.key_ids: - key = result_keys.get(key_id) - if key: - with PreserveLoggingContext(): - verify_request.deferred.callback( - (server_name, key_id, key) - ) - break - else: - # The else block is only reached if the loop above - # doesn't break. - missing_keys.setdefault(server_name, set()).update( - verify_request.key_ids - ) - requests_missing_keys.append(verify_request) - - if not missing_keys: - break + for f in self._key_fetchers: + if not remaining_requests: + return + yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) + # look for any requests which weren't satisfied with PreserveLoggingContext(): - for verify_request in requests_missing_keys: - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + for verify_request in remaining_requests: + verify_request.key_ready.errback( + SynapseError( + 401, + "No key for %s with ids in %s (min_validity %i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, + ), + Codes.UNAUTHORIZED, + ) + ) def on_err(err): + # we don't really expect to get here, because any errors should already + # have been caught and logged. But if we do, let's log the error and make + # sure that all of the deferreds are resolved. + logger.error("Unexpected error in _get_server_verify_keys: %s", err) with PreserveLoggingContext(): - for verify_request in verify_requests: - if not verify_request.deferred.called: - verify_request.deferred.errback(err) + for verify_request in remaining_requests: + if not verify_request.key_ready.called: + verify_request.key_ready.errback(err) run_in_background(do_iterations).addErrback(on_err) @defer.inlineCallbacks - def get_keys_from_store(self, server_name_and_key_ids): + def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + """Use a key fetcher to attempt to satisfy some key requests + + Args: + fetcher (KeyFetcher): fetcher to use to fetch the keys + remaining_requests (set[VerifyKeyRequest]): outstanding key requests. + Any successfully-completed requests will be removed from the list. + """ + # dict[str, dict[str, int]]: keys to fetch. + # server_name -> key_id -> min_valid_ts + missing_keys = defaultdict(dict) + + for verify_request in remaining_requests: + # any completed requests should already have been removed + assert not verify_request.key_ready.called + keys_for_server = missing_keys[verify_request.server_name] + + for key_id in verify_request.key_ids: + # If we have several requests for the same key, then we only need to + # request that key once, but we should do so with the greatest + # min_valid_until_ts of the requests, so that we can satisfy all of + # the requests. + keys_for_server[key_id] = max( + keys_for_server.get(key_id, -1), + verify_request.minimum_valid_until_ts + ) + + results = yield fetcher.get_keys(missing_keys) + + completed = list() + for verify_request in remaining_requests: + server_name = verify_request.server_name + + # see if any of the keys we got this time are sufficient to + # complete this VerifyKeyRequest. + result_keys = results.get(server_name, {}) + for key_id in verify_request.key_ids: + fetch_key_result = result_keys.get(key_id) + if not fetch_key_result: + # we didn't get a result for this key + continue + + if ( + fetch_key_result.valid_until_ts + < verify_request.minimum_valid_until_ts + ): + # key was not valid at this point + continue + + with PreserveLoggingContext(): + verify_request.key_ready.callback( + (server_name, key_id, fetch_key_result.verify_key) + ) + completed.append(verify_request) + break + + remaining_requests.difference_update(completed) + + +class KeyFetcher(object): + def get_keys(self, keys_to_fetch): """ Args: - server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): - list of (server_name, iterable[key_id]) tuples to fetch keys for + keys_to_fetch (dict[str, dict[str, int]]): + the keys to be fetched. server_name -> key_id -> min_valid_ts Returns: - Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from - server_name -> key_id -> VerifyKey + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + map from server_name -> key_id -> FetchKeyResult """ + raise NotImplementedError + + +class StoreKeyFetcher(KeyFetcher): + """KeyFetcher impl which fetches keys from our data store""" + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def get_keys(self, keys_to_fetch): + """see KeyFetcher.get_keys""" + keys_to_fetch = ( (server_name, key_id) - for server_name, key_ids in server_name_and_key_ids - for key_id in key_ids + for server_name, keys_for_server in keys_to_fetch.items() + for key_id in keys_for_server.keys() ) + res = yield self.store.get_server_verify_keys(keys_to_fetch) keys = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks - def get_keys_from_perspectives(self, server_name_and_key_ids): - @defer.inlineCallbacks - def get_key(perspective_name, perspective_keys): - try: - result = yield self.get_server_verify_key_v2_indirect( - server_name_and_key_ids, perspective_name, perspective_keys - ) - defer.returnValue(result) - except KeyLookupError as e: - logger.warning( - "Key lookup failed from %r: %s", perspective_name, e, - ) - except Exception as e: - logger.exception( - "Unable to get key from %r: %s %s", - perspective_name, - type(e).__name__, str(e), - ) - defer.returnValue({}) - - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background(get_key, p_name, p_keys) - for p_name, p_keys in self.perspective_servers.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) - - union_of_keys = {} - for result in results: - for server_name, keys in result.items(): - union_of_keys.setdefault(server_name, {}).update(keys) - - defer.returnValue(union_of_keys) - - @defer.inlineCallbacks - def get_keys_from_server(self, server_name_and_key_ids): - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.get_server_verify_key_v2_direct, - server_name, - key_ids, - ) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) - - merged = {} - for result in results: - merged.update(result) - - defer.returnValue({ - server_name: keys - for server_name, keys in merged.items() - if keys - }) - - @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, - perspective_name, - perspective_keys): - # TODO(mark): Set the minimum_valid_until_ts to that needed by - # the events being validated or the current time if validating - # an incoming request. - try: - query_response = yield self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - u"server_keys": { - server_name: { - key_id: { - u"minimum_valid_until_ts": 0 - } for key_id in key_ids - } - for server_name, key_ids in server_names_and_key_ids - } - }, - long_retries=True, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - raise_from( - KeyLookupError("Failed to connect to remote server"), e, - ) - except HttpResponseException as e: - raise_from( - KeyLookupError("Remote server returned an error"), e, - ) - - keys = {} - - responses = query_response["server_keys"] - - for response in responses: - if (u"signatures" not in response - or perspective_name not in response[u"signatures"]): - raise KeyLookupError( - "Key response not signed by perspective server" - " %r" % (perspective_name,) - ) - - verified = False - for key_id in response[u"signatures"][perspective_name]: - if key_id in perspective_keys: - verify_signed_json( - response, - perspective_name, - perspective_keys[key_id] - ) - verified = True - - if not verified: - logging.info( - "Response from perspective server %r not signed with a" - " known key, signed with: %r, known keys: %r", - perspective_name, - list(response[u"signatures"][perspective_name]), - list(perspective_keys) - ) - raise KeyLookupError( - "Response not signed with a known key for perspective" - " server %r" % (perspective_name,) - ) - - processed_response = yield self.process_v2_response( - perspective_name, response - ) - server_name = response["server_name"] - - keys.setdefault(server_name, {}).update(processed_response) - - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=server_name, - from_server=perspective_name, - verify_keys=response_keys, - ) - for server_name, response_keys in keys.items() - ], - consumeErrors=True - ).addErrback(unwrapFirstError)) - - defer.returnValue(keys) - - @defer.inlineCallbacks - def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} # type: dict[str, nacl.signing.VerifyKey] - - for requested_key_id in key_ids: - if requested_key_id in keys: - continue - - try: - response = yield self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - raise_from( - KeyLookupError("Failed to connect to remote server"), e, - ) - except HttpResponseException as e: - raise_from( - KeyLookupError("Remote server returned an error"), e, - ) - - if (u"signatures" not in response - or server_name not in response[u"signatures"]): - raise KeyLookupError("Key response not signed by remote server") - - if response["server_name"] != server_name: - raise KeyLookupError("Expected a response for server %r not %r" % ( - server_name, response["server_name"] - )) - - response_keys = yield self.process_v2_response( - from_server=server_name, - requested_ids=[requested_key_id], - response_json=response, - ) - - keys.update(response_keys) - - yield self.store_keys( - server_name=server_name, - from_server=server_name, - verify_keys=keys, - ) - defer.returnValue({server_name: keys}) +class BaseV2KeyFetcher(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.config = hs.get_config() @defer.inlineCallbacks def process_v2_response( - self, from_server, response_json, requested_ids=[], + self, from_server, response_json, time_added_ms ): """Parse a 'Server Keys' structure from the result of a /key request @@ -567,8 +464,7 @@ class Keyring(object): POST /_matrix/key/v2/query. Checks that each signature in the response that claims to come from the origin - server is valid. (Does not check that there actually is such a signature, for - some reason.) + server is valid, and that there is at least one such signature. Stores the json in server_keys_json so that it can be used for future responses to /_matrix/key/v2/query. @@ -580,103 +476,375 @@ class Keyring(object): response_json (dict): the json-decoded Server Keys response object - requested_ids (iterable[str]): a list of the key IDs that were requested. - We will store the json for these key ids as well as any that are - actually in the response + time_added_ms (int): the timestamp to record in server_keys_json Returns: - Deferred[dict[str, nacl.signing.VerifyKey]]: - map from key_id to key object + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object """ - time_now_ms = self.clock.time_msec() - response_keys = {} + ts_valid_until_ms = response_json[u"valid_until_ts"] + + # start by extracting the keys from the response, since they may be required + # to validate the signature on the response. verify_keys = {} for key_id, key_data in response_json["verify_keys"].items(): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_key.time_added = time_now_ms - verify_keys[key_id] = verify_key + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=ts_valid_until_ms + ) + + server_name = response_json["server_name"] + verified = False + for key_id in response_json["signatures"].get(server_name, {}): + # each of the keys used for the signature must be present in the response + # json. + key = verify_keys.get(key_id) + if not key: + raise KeyLookupError( + "Key response is signed by key id %s:%s but that key is not " + "present in the response" % (server_name, key_id) + ) + + verify_signed_json(response_json, server_name, key.verify_key) + verified = True + + if not verified: + raise KeyLookupError( + "Key response for %s is not signed by the origin server" + % (server_name,) + ) - old_verify_keys = {} for key_id, key_data in response_json["old_verify_keys"].items(): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_key.expired = key_data["expired_ts"] - verify_key.time_added = time_now_ms - old_verify_keys[key_id] = verify_key - - server_name = response_json["server_name"] - for key_id in response_json["signatures"].get(server_name, {}): - if key_id not in response_json["verify_keys"]: - raise KeyLookupError( - "Key response must include verification keys for all" - " signatures" - ) - if key_id in verify_keys: - verify_signed_json( - response_json, - server_name, - verify_keys[key_id] + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=key_data["expired_ts"] ) + # re-sign the json with our own key, so that it is ready if we are asked to + # give it out as a notary server signed_key_json = sign_json( - response_json, - self.config.server_name, - self.config.signing_key[0], + response_json, self.config.server_name, self.config.signing_key[0] ) signed_key_json_bytes = encode_canonical_json(signed_key_json) - ts_valid_until_ms = signed_key_json[u"valid_until_ts"] - updated_key_ids = set(requested_ids) - updated_key_ids.update(verify_keys) - updated_key_ids.update(old_verify_keys) + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store.store_server_keys_json, + server_name=server_name, + key_id=key_id, + from_server=from_server, + ts_now_ms=time_added_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in verify_keys + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) - response_keys.update(verify_keys) - response_keys.update(old_verify_keys) + defer.returnValue(verify_keys) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_now_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, + +class PerspectivesKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the "perspectives" servers""" + + def __init__(self, hs): + super(PerspectivesKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + self.perspective_servers = self.config.perspectives + + @defer.inlineCallbacks + def get_keys(self, keys_to_fetch): + """see KeyFetcher.get_keys""" + + @defer.inlineCallbacks + def get_key(perspective_name, perspective_keys): + try: + result = yield self.get_server_verify_key_v2_indirect( + keys_to_fetch, perspective_name, perspective_keys + ) + defer.returnValue(result) + except KeyLookupError as e: + logger.warning("Key lookup failed from %r: %s", perspective_name, e) + except Exception as e: + logger.exception( + "Unable to get key from %r: %s %s", + perspective_name, + type(e).__name__, + str(e), ) - for key_id in updated_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) - defer.returnValue(response_keys) + defer.returnValue({}) - def store_keys(self, server_name, from_server, verify_keys): - """Store a collection of verify keys for a given server - Args: - server_name(str): The name of the server the keys are for. - from_server(str): The server the keys were downloaded from. - verify_keys(dict): A mapping of key_id to VerifyKey. - Returns: - A deferred that completes when the keys are stored. + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(get_key, p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + union_of_keys = {} + for result in results: + for server_name, keys in result.items(): + union_of_keys.setdefault(server_name, {}).update(keys) + + defer.returnValue(union_of_keys) + + @defer.inlineCallbacks + def get_server_verify_key_v2_indirect( + self, keys_to_fetch, perspective_name, perspective_keys + ): """ - # TODO(markjh): Store whether the keys have expired. - return logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.store_server_verify_key, - server_name, server_name, key.time_added, key + Args: + keys_to_fetch (dict[str, dict[str, int]]): + the keys to be fetched. server_name -> key_id -> min_valid_ts + + perspective_name (str): name of the notary server to query for the keys + + perspective_keys (dict[str, VerifyKey]): map of key_id->key for the + notary server + + Returns: + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + from server_name -> key_id -> FetchKeyResult + + Raises: + KeyLookupError if there was an error processing the entire response from + the server + """ + logger.info( + "Requesting keys %s from notary server %s", + keys_to_fetch.items(), + perspective_name, + ) + + try: + query_response = yield self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + u"server_keys": { + server_name: { + key_id: {u"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() + } + for server_name, server_keys in keys_to_fetch.items() + } + }, + ) + except (NotRetryingDestination, RequestSendFailed) as e: + raise_from(KeyLookupError("Failed to connect to remote server"), e) + except HttpResponseException as e: + raise_from(KeyLookupError("Remote server returned an error"), e) + + keys = {} + added_keys = [] + + time_now_ms = self.clock.time_msec() + + for response in query_response["server_keys"]: + # do this first, so that we can give useful errors thereafter + server_name = response.get("server_name") + if not isinstance(server_name, six.string_types): + raise KeyLookupError( + "Malformed response from key notary server %s: invalid server_name" + % (perspective_name,) ) - for key_id, key in verify_keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) + + try: + processed_response = yield self._process_perspectives_response( + perspective_name, + perspective_keys, + response, + time_added_ms=time_now_ms, + ) + except KeyLookupError as e: + logger.warning( + "Error processing response from key notary server %s for origin " + "server %s: %s", + perspective_name, + server_name, + e, + ) + # we continue to process the rest of the response + continue + + added_keys.extend( + (server_name, key_id, key) for key_id, key in processed_response.items() + ) + keys.setdefault(server_name, {}).update(processed_response) + + yield self.store.store_server_verify_keys( + perspective_name, time_now_ms, added_keys + ) + + defer.returnValue(keys) + + def _process_perspectives_response( + self, perspective_name, perspective_keys, response, time_added_ms + ): + """Parse a 'Server Keys' structure from the result of a /key/query request + + Checks that the entry is correctly signed by the perspectives server, and then + passes over to process_v2_response + + Args: + perspective_name (str): the name of the notary server that produced this + result + + perspective_keys (dict[str, VerifyKey]): map of key_id->key for the + notary server + + response (dict): the json-decoded Server Keys response object + + time_added_ms (int): the timestamp to record in server_keys_json + + Returns: + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + """ + if ( + u"signatures" not in response + or perspective_name not in response[u"signatures"] + ): + raise KeyLookupError("Response not signed by the notary server") + + verified = False + for key_id in response[u"signatures"][perspective_name]: + if key_id in perspective_keys: + verify_signed_json(response, perspective_name, perspective_keys[key_id]) + verified = True + + if not verified: + raise KeyLookupError( + "Response not signed with a known key: signed with: %r, known keys: %r" + % ( + list(response[u"signatures"][perspective_name].keys()), + list(perspective_keys.keys()), + ) + ) + + return self.process_v2_response( + perspective_name, response, time_added_ms=time_added_ms + ) + + +class ServerKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the origin servers""" + + def __init__(self, hs): + super(ServerKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + + def get_keys(self, keys_to_fetch): + """ + Args: + keys_to_fetch (dict[str, iterable[str]]): + the keys to be fetched. server_name -> key_ids + + Returns: + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + map from server_name -> key_id -> FetchKeyResult + """ + + results = {} + + @defer.inlineCallbacks + def get_key(key_to_fetch_item): + server_name, key_ids = key_to_fetch_item + try: + keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) + results[server_name] = keys + except KeyLookupError as e: + logger.warning( + "Error looking up keys %s from %s: %s", key_ids, server_name, e + ) + except Exception: + logger.exception("Error getting keys %s from %s", key_ids, server_name) + + return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( + lambda _: results + ) + + @defer.inlineCallbacks + def get_server_verify_key_v2_direct(self, server_name, key_ids): + """ + + Args: + server_name (str): + key_ids (iterable[str]): + + Returns: + Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result + + Raises: + KeyLookupError if there was a problem making the lookup + """ + keys = {} # type: dict[str, FetchKeyResult] + + for requested_key_id in key_ids: + # we may have found this key as a side-effect of asking for another. + if requested_key_id in keys: + continue + + time_now_ms = self.clock.time_msec() + try: + response = yield self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, + ) + except (NotRetryingDestination, RequestSendFailed) as e: + raise_from(KeyLookupError("Failed to connect to remote server"), e) + except HttpResponseException as e: + raise_from(KeyLookupError("Remote server returned an error"), e) + + if response["server_name"] != server_name: + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) + ) + + response_keys = yield self.process_v2_response( + from_server=server_name, + response_json=response, + time_added_ms=time_now_ms, + ) + yield self.store.store_server_verify_keys( + server_name, + time_now_ms, + ((server_name, key_id, key) for key_id, key in response_keys.items()), + ) + keys.update(response_keys) + + defer.returnValue(keys) @defer.inlineCallbacks @@ -693,48 +861,25 @@ def _handle_key_deferred(verify_request): SynapseError if there was a problem performing the verification """ server_name = verify_request.server_name - try: - with PreserveLoggingContext(): - _, key_id, verify_key = yield verify_request.deferred - except KeyLookupError as e: - logger.warn( - "Failed to download keys for %s: %s %s", - server_name, type(e).__name__, str(e), - ) - raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, - ) - except Exception as e: - logger.exception( - "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e), - ) - raise SynapseError( - 401, - "No key for %s with id %s" % (server_name, verify_request.key_ids), - Codes.UNAUTHORIZED, - ) + with PreserveLoggingContext(): + _, key_id, verify_key = yield verify_request.key_ready json_object = verify_request.json_object - logger.debug("Got key %s %s:%s for server %s, verifying" % ( - key_id, verify_key.alg, verify_key.version, server_name, - )) try: verify_signed_json(json_object, server_name, verify_key) except SignatureVerifyException as e: logger.debug( "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, verify_key.alg, verify_key.version, + server_name, + verify_key.alg, + verify_key.version, encode_verify_key_base64(verify_key), str(e), ) raise SynapseError( 401, - "Invalid signature for server %s with key %s:%s: %s" % ( - server_name, verify_key.alg, verify_key.version, str(e), - ), + "Invalid signature for server %s with key %s:%s: %s" + % (server_name, verify_key.alg, verify_key.version, str(e)), Codes.UNAUTHORIZED, ) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 12056d5be..1edd19cc1 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -21,6 +21,7 @@ import six from unpaddedbase64 import encode_base64 +from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -335,13 +336,32 @@ class FrozenEventV2(EventBase): return self.__repr__() def __repr__(self): - return "" % ( + return "<%s event_id='%s', type='%s', state_key='%s'>" % ( + self.__class__.__name__, self.event_id, self.get("type", None), self.get("state_key", None), ) +class FrozenEventV3(FrozenEventV2): + """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 + + @property + def event_id(self): + # We have to import this here as otherwise we get an import loop which + # is hard to break. + from synapse.crypto.event_signing import compute_event_reference_hash + + if self._event_id: + return self._event_id + self._event_id = "$" + encode_base64( + compute_event_reference_hash(self)[1], urlsafe=True + ) + return self._event_id + + def room_version_to_event_format(room_version): """Converts a room version string to the event format @@ -350,12 +370,15 @@ def room_version_to_event_format(room_version): Returns: int + + Raises: + UnsupportedRoomVersionError if the room version is unknown """ v = KNOWN_ROOM_VERSIONS.get(room_version) if not v: - # We should have already checked version, so this should not happen - raise RuntimeError("Unrecognized room version %s" % (room_version,)) + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError() return v.event_format @@ -376,6 +399,8 @@ def event_type_from_format_version(format_version): return FrozenEvent elif format_version == EventFormatVersions.V2: return FrozenEventV2 + elif format_version == EventFormatVersions.V3: + return FrozenEventV3 else: raise Exception( "No event format %r" % (format_version,) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index fba27177c..546b6f498 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -18,6 +18,7 @@ import attr from twisted.internet import defer from synapse.api.constants import MAX_DEPTH +from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, @@ -75,6 +76,7 @@ class EventBuilder(object): # someone tries to get them when they don't exist. _state_key = attr.ib(default=None) _redacts = attr.ib(default=None) + _origin_server_ts = attr.ib(default=None) internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({}))) @@ -141,6 +143,9 @@ class EventBuilder(object): if self._redacts is not None: event_dict["redacts"] = self._redacts + if self._origin_server_ts is not None: + event_dict["origin_server_ts"] = self._origin_server_ts + defer.returnValue( create_local_event_from_event_dict( clock=self._clock, @@ -178,9 +183,8 @@ class EventBuilderFactory(object): """ v = KNOWN_ROOM_VERSIONS.get(room_version) if not v: - raise Exception( - "No event format defined for version %r" % (room_version,) - ) + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError() return self.for_room_version(v, key_values) def for_room_version(self, room_version, key_values): @@ -209,6 +213,7 @@ class EventBuilderFactory(object): content=key_values.get("content", {}), unsigned=key_values.get("unsigned", {}), redacts=key_values.get("redacts", None), + origin_server_ts=key_values.get("origin_server_ts", None), ) @@ -245,7 +250,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, event_dict["event_id"] = _create_event_id(clock, hostname) event_dict["origin"] = hostname - event_dict["origin_server_ts"] = time_now + event_dict.setdefault("origin_server_ts", time_now) event_dict.setdefault("unsigned", {}) age = event_dict["unsigned"].pop("age", 0) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 368b5f6ae..fa09c132a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -187,7 +187,9 @@ class EventContext(object): Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + is None, which happens when the associated event is an outlier. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if not self._fetching_state_deferred: @@ -205,7 +207,9 @@ class EventContext(object): Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + is None, which happens when the associated event is an outlier. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if not self._fetching_state_deferred: diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 07fccdd8f..e2d4384de 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -19,7 +19,10 @@ from six import string_types from frozendict import frozendict -from synapse.api.constants import EventTypes +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -311,3 +314,93 @@ def serialize_event(e, time_now_ms, as_client_event=True, d = only_fields(d, only_event_fields) return d + + +class EventClientSerializer(object): + """Serializes events that are to be sent to clients. + + This is used for bundling extra information with any events to be sent to + clients. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + self.experimental_msc1849_support_enabled = ( + hs.config.experimental_msc1849_support_enabled + ) + + @defer.inlineCallbacks + def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): + """Serializes a single event. + + Args: + event (EventBase) + time_now (int): The current time in milliseconds + bundle_aggregations (bool): Whether to bundle in related events + **kwargs: Arguments to pass to `serialize_event` + + Returns: + Deferred[dict]: The serialized event + """ + # To handle the case of presence events and the like + if not isinstance(event, EventBase): + defer.returnValue(event) + + event_id = event.event_id + serialized_event = serialize_event(event, time_now, **kwargs) + + # If MSC1849 is enabled then we need to look if thre are any relations + # we need to bundle in with the event + if self.experimental_msc1849_support_enabled and bundle_aggregations: + annotations = yield self.store.get_aggregation_groups_for_event( + event_id, + ) + references = yield self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f", + ) + + if annotations.chunk: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.ANNOTATION] = annotations.to_dict() + + if references.chunk: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = yield self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + relations = event.content.get("m.relates_to") + serialized_event["content"] = edit.content.get("m.new_content", {}) + if relations: + serialized_event["content"]["m.relates_to"] = relations + else: + serialized_event["content"].pop("m.relates_to", None) + + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + } + + defer.returnValue(serialized_event) + + def serialize_events(self, events, time_now, **kwargs): + """Serializes multiple events. + + Args: + event (iter[EventBase]) + time_now (int): The current time in milliseconds + **kwargs: Arguments to pass to `serialize_event` + + Returns: + Deferred[list[dict]]: The list of serialized events + """ + return yieldable_gather_results( + self.serialize_event, events, + time_now=time_now, **kwargs + ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 514273c79..711af512b 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -15,8 +15,8 @@ from six import string_types -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership +from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions from synapse.types import EventID, RoomID, UserID @@ -56,6 +56,17 @@ class EventValidator(object): if not isinstance(getattr(event, s), string_types): raise SynapseError(400, "'%s' not a string type" % (s,)) + if event.type == EventTypes.Aliases: + if "aliases" in event.content: + for alias in event.content["aliases"]: + if len(alias) > MAX_ALIAS_LENGTH: + raise SynapseError( + 400, + ("Can't create aliases longer than" + " %d characters" % (MAX_ALIAS_LENGTH,)), + Codes.INVALID_PARAM, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index dfe6b4aa5..4b38f7c75 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -265,11 +265,22 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): ] more_deferreds = keyring.verify_json_objects_for_server([ - (p.sender_domain, p.redacted_pdu_json) + (p.sender_domain, p.redacted_pdu_json, 0) for p in pdus_to_check_sender ]) + def sender_err(e, pdu_to_check): + errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( + pdu_to_check.pdu.event_id, + pdu_to_check.sender_domain, + e.getErrorMessage(), + ) + # XX not really sure if these are the right codes, but they are what + # we've done for ages + raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + for p, d in zip(pdus_to_check_sender, more_deferreds): + d.addErrback(sender_err, p) p.deferreds.append(d) # now let's look for events where the sender's domain is different to the @@ -287,11 +298,22 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): ] more_deferreds = keyring.verify_json_objects_for_server([ - (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json) + (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0) for p in pdus_to_check_event_id ]) + def event_err(e, pdu_to_check): + errmsg = ( + "event id %s: unable to verify signature for event id domain: %s" % ( + pdu_to_check.pdu.event_id, + e.getErrorMessage(), + ) + ) + # XX as above: not really sure if these are the right codes + raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + for p, d in zip(pdus_to_check_event_id, more_deferreds): + d.addErrback(event_err, p) p.deferreds.append(d) # replace lists of deferreds with single Deferreds diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f3fc897a0..70573746d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,7 +17,6 @@ import copy import itertools import logging -import random from six.moves import range @@ -233,7 +232,8 @@ class FederationClient(FederationBase): moving to the next destination. None indicates no timeout. Returns: - Deferred: Results in the requested PDU. + Deferred: Results in the requested PDU, or None if we were unable to find + it. """ # TODO: Rate limit the number of times we try and get the same event. @@ -258,7 +258,12 @@ class FederationClient(FederationBase): destination, event_id, timeout=timeout, ) - logger.debug("transaction_data %r", transaction_data) + logger.debug( + "retrieved event id %s from %s: %r", + event_id, + destination, + transaction_data, + ) pdu_list = [ event_from_pdu_json(p, format_ver, outlier=outlier) @@ -280,6 +285,7 @@ class FederationClient(FederationBase): "Failed to get PDU %s from %s because %s", event_id, destination, e, ) + continue except NotRetryingDestination as e: logger.info(str(e)) continue @@ -326,12 +332,16 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - fetched_events, failed_to_fetch = yield self.get_events( - [destination], room_id, set(state_event_ids + auth_event_ids) + fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( + destination, room_id, set(state_event_ids + auth_event_ids) ) if failed_to_fetch: - logger.warn("Failed to get %r", failed_to_fetch) + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch + ) event_map = { ev.event_id: ev for ev in fetched_events @@ -397,27 +407,20 @@ class FederationClient(FederationBase): defer.returnValue((signed_pdus, signed_auth)) @defer.inlineCallbacks - def get_events(self, destinations, room_id, event_ids, return_local=True): - """Fetch events from some remote destinations, checking if we already - have them. + def get_events_from_store_or_dest(self, destination, room_id, event_ids): + """Fetch events from a remote destination, checking if we already have them. Args: - destinations (list) + destination (str) room_id (str) event_ids (list) - return_local (bool): Whether to include events we already have in - the DB in the returned list of events Returns: Deferred: A deferred resolving to a 2-tuple where the first is a list of events and the second is a list of event ids that we failed to fetch. """ - if return_local: - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - else: - seen_events = yield self.store.have_seen_events(event_ids) - signed_events = [] + seen_events = yield self.store.get_events(event_ids, allow_rejected=True) + signed_events = list(seen_events.values()) failed_to_fetch = set() @@ -428,10 +431,11 @@ class FederationClient(FederationBase): if not missing_events: defer.returnValue((signed_events, failed_to_fetch)) - def random_server_list(): - srvs = list(destinations) - random.shuffle(srvs) - return srvs + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + event_ids, + ) room_version = yield self.store.get_room_version(room_id) @@ -443,7 +447,7 @@ class FederationClient(FederationBase): deferreds = [ run_in_background( self.get_pdu, - destinations=random_server_list(), + destinations=[destination], event_id=e_id, room_version=room_version, ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index df60828db..4c28c1dc3 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( IncompatibleRoomVersionError, NotFoundError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.crypto.event_signing import compute_event_signature @@ -198,11 +199,22 @@ class FederationServer(FederationBase): try: room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue + try: + format_ver = room_version_to_event_format(room_version) + except UnsupportedRoomVersionError: + # this can happen if support for a given room version is withdrawn, + # so that we still get events for said room. + logger.info( + "Ignoring PDU for room %s with unknown version %s", + room_id, + room_version, + ) + continue + event = event_from_pdu_json(p, format_ver) pdus_by_room.setdefault(room_id, []).append(event) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index be9921100..fae8bea39 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -33,12 +33,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +# This is defined in the Matrix spec and enforced by the receiver. +MAX_EDUS_PER_TRANSACTION = 100 + logger = logging.getLogger(__name__) sent_edus_counter = Counter( - "synapse_federation_client_sent_edus", - "Total number of EDUs successfully sent", + "synapse_federation_client_sent_edus", "Total number of EDUs successfully sent" ) sent_edus_by_type = Counter( @@ -58,6 +60,7 @@ class PerDestinationQueue(object): destination (str): the server_name of the destination that we are managing transmission for. """ + def __init__(self, hs, transaction_manager, destination): self._server_name = hs.hostname self._clock = hs.get_clock() @@ -68,17 +71,17 @@ class PerDestinationQueue(object): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: list[tuple[EventBase, int]] + self._pending_edus = [] # type: list[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict self._pending_rrs = {} @@ -120,9 +123,7 @@ class PerDestinationQueue(object): Args: states (iterable[UserPresenceState]): presence to send """ - self._pending_presence.update({ - state.user_id: state for state in states - }) + self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() def queue_read_receipt(self, receipt): @@ -132,14 +133,9 @@ class PerDestinationQueue(object): Args: receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued """ - self._pending_rrs.setdefault( - receipt.room_id, {}, - ).setdefault( + self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} - )[receipt.user_id] = { - "event_ids": receipt.event_ids, - "data": receipt.data, - } + )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} def flush_read_receipts_for_room(self, room_id): # if we don't have any read-receipts for this room, it may be that we've already @@ -170,10 +166,7 @@ class PerDestinationQueue(object): # request at which point pending_pdus just keeps growing. # we need application-layer timeouts of some flavour of these # requests - logger.debug( - "TX [%s] Transaction already in progress", - self._destination - ) + logger.debug("TX [%s] Transaction already in progress", self._destination) return logger.debug("TX [%s] Starting transaction loop", self._destination) @@ -197,7 +190,8 @@ class PerDestinationQueue(object): pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( - yield self._get_new_device_messages() + # We have to keep 2 free slots for presence and rr_edus + yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2) ) # BEGIN CRITICAL SECTION @@ -216,19 +210,9 @@ class PerDestinationQueue(object): pending_edus = [] - pending_edus.extend(self._get_rr_edus(force_flush=False)) - # We can only include at most 100 EDUs per transactions - pending_edus.extend(self._pop_pending_edus(100 - len(pending_edus))) - - pending_edus.extend( - self._pending_edus_keyed.values() - ) - - self._pending_edus_keyed = {} - - pending_edus.extend(device_message_edus) - + # rr_edus and pending_presence take at most one slot each + pending_edus.extend(self._get_rr_edus(force_flush=False)) pending_presence = self._pending_presence self._pending_presence = {} if pending_presence: @@ -248,9 +232,23 @@ class PerDestinationQueue(object): ) ) + pending_edus.extend(device_message_edus) + pending_edus.extend( + self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self._pending_edus_keyed + ): + _, val = self._pending_edus_keyed.popitem() + pending_edus.append(val) + if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - self._destination, len(pending_pdus)) + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), + ) if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", self._destination) @@ -259,7 +257,7 @@ class PerDestinationQueue(object): # if we've decided to send a transaction anyway, and we have room, we # may as well send any pending RRs - if len(pending_edus) < 100: + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: pending_edus.extend(self._get_rr_edus(force_flush=True)) # END CRITICAL SECTION @@ -303,22 +301,25 @@ class PerDestinationQueue(object): except HttpResponseException as e: logger.warning( "TX [%s] Received %d response to transaction: %s", - self._destination, e.code, e, + self._destination, + e.code, + e, ) except RequestSendFailed as e: - logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e) + logger.warning( + "TX [%s] Failed to send transaction: %s", self._destination, e + ) for p, _ in pending_pdus: - logger.info("Failed to send event %s to %s", p.event_id, - self._destination) + logger.info( + "Failed to send event %s to %s", p.event_id, self._destination + ) except Exception: - logger.exception( - "TX [%s] Failed to send transaction", - self._destination, - ) + logger.exception("TX [%s] Failed to send transaction", self._destination) for p, _ in pending_pdus: - logger.info("Failed to send event %s to %s", p.event_id, - self._destination) + logger.info( + "Failed to send event %s to %s", p.event_id, self._destination + ) finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -346,27 +347,13 @@ class PerDestinationQueue(object): return pending_edus @defer.inlineCallbacks - def _get_new_device_messages(self): - last_device_stream_id = self._last_device_stream_id - to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, last_device_stream_id, to_device_stream_id - ) - edus = [ - Edu( - origin=self._server_name, - destination=self._destination, - edu_type="m.direct_to_device", - content=content, - ) - for content in contents - ] - + def _get_new_device_messages(self, limit): last_device_list = self._last_device_list_stream_id + # Will return at most 20 entries now_stream_id, results = yield self._store.get_devices_by_remote( self._destination, last_device_list ) - edus.extend( + edus = [ Edu( origin=self._server_name, destination=self._destination, @@ -374,5 +361,26 @@ class PerDestinationQueue(object): content=content, ) for content in results + ] + + assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + + last_device_stream_id = self._last_device_stream_id + to_device_stream_id = self._store.get_to_device_stream_token() + contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + self._destination, + last_device_stream_id, + to_device_stream_id, + limit - len(edus), ) + edus.extend( + Edu( + origin=self._server_name, + destination=self._destination, + edu_type="m.direct_to_device", + content=content, + ) + for content in contents + ) + defer.returnValue((edus, stream_id, now_stream_id)) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 452599e1a..0db8858cf 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -23,7 +23,11 @@ from twisted.internet import defer import synapse from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions -from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX +from synapse.api.urls import ( + FEDERATION_UNSTABLE_PREFIX, + FEDERATION_V1_PREFIX, + FEDERATION_V2_PREFIX, +) from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -63,11 +67,7 @@ class TransportLayerServer(JsonResource): self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( self.clock, - window_size=hs.config.federation_rc_window_size, - sleep_limit=hs.config.federation_rc_sleep_limit, - sleep_msec=hs.config.federation_rc_sleep_delay, - reject_limit=hs.config.federation_rc_reject_limit, - concurrent_requests=hs.config.federation_rc_concurrent, + config=hs.config.rc_federation, ) self.register_servlets() @@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): def __init__(self, hs): + self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() @@ -102,6 +103,7 @@ class Authenticator(object): # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks def authenticate_request(self, request, content): + now = self._clock.time_msec() json_request = { "method": request.method.decode('ascii'), "uri": request.uri.decode('ascii'), @@ -138,7 +140,7 @@ class Authenticator(object): 401, "Missing Authorization headers", Codes.UNAUTHORIZED, ) - yield self.keyring.verify_json_for_server(origin, json_request) + yield self.keyring.verify_json_for_server(origin, json_request, now) logger.info("Request from %s", origin) request.authenticated_entity = origin @@ -716,8 +718,17 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" + def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access): + super(PublicRoomList, self).__init__( + handler, authenticator, ratelimiter, server_name, + ) + self.deny_access = deny_access + @defer.inlineCallbacks def on_GET(self, origin, content, query): + if self.deny_access: + raise FederationDeniedError(origin) + limit = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( @@ -1299,6 +1310,30 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): defer.returnValue((200, new_content)) +class RoomComplexityServlet(BaseFederationServlet): + """ + Indicates to other servers how complex (and therefore likely + resource-intensive) a public room this server knows about is. + """ + PATH = "/rooms/(?P[^/]*)/complexity" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, room_id): + + store = self.handler.hs.get_datastore() + + is_public = yield store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + if not is_public: + raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) + + complexity = yield store.get_room_complexity(room_id) + defer.returnValue((200, complexity)) + + FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, @@ -1322,6 +1357,7 @@ FEDERATION_SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, FederationVersionServlet, + RoomComplexityServlet, ) OPENID_SERVLET_CLASSES = ( @@ -1417,6 +1453,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, + deny_access=hs.config.restrict_public_rooms_to_local_users, ).register(resource) if "group_server" in servlet_groups: diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 786149be6..fa6b641ee 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -97,10 +97,11 @@ class GroupAttestationSigning(object): # TODO: We also want to check that *new* attestations that people give # us to store are valid for at least a little while. - if valid_until_ms < self.clock.time_msec(): + now = self.clock.time_msec() + if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server(server_name, attestation) + yield self.keyring.verify_json_for_server(server_name, attestation, now) def create_attestation(self, group_id, user_id): """Create an attestation for the group_id and user_id with default diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index ac09d03ba..dca337ec6 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -90,8 +90,8 @@ class BaseHandler(object): messages_per_second = override.messages_per_second burst_count = override.burst_count else: - messages_per_second = self.hs.config.rc_messages_per_second - burst_count = self.hs.config.rc_message_burst_count + messages_per_second = self.hs.config.rc_message.per_second + burst_count = self.hs.config.rc_message.burst_count allowed, time_allowed = self.ratelimiter.can_do_action( user_id, time_now, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 27bd06df5..a12f9508d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ import string from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, CodeMessageException, @@ -43,8 +43,10 @@ class DirectoryHandler(BaseHandler): self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastore() self.config = hs.config self.enable_room_list_search = hs.config.enable_room_list_search + self.require_membership = hs.config.require_membership_for_aliases self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -83,7 +85,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True): + send_event=True, check_membership=True): """Attempt to create a new alias Args: @@ -93,6 +95,8 @@ class DirectoryHandler(BaseHandler): servers (list[str]|None): List of servers that others servers should try and join via send_event (bool): Whether to send an updated m.room.aliases event + check_membership (bool): Whether to check if the user is in the room + before the alias can be set (if the server's config requires it). Returns: Deferred @@ -100,6 +104,13 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() + if len(room_alias.to_string()) > MAX_ALIAS_LENGTH: + raise SynapseError( + 400, + "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH, + Codes.INVALID_PARAM, + ) + service = requester.app_service if service: if not service.is_interested_in_alias(room_alias.to_string()): @@ -108,6 +119,14 @@ class DirectoryHandler(BaseHandler): " this kind of alias.", errcode=Codes.EXCLUSIVE ) else: + if self.require_membership and check_membership: + rooms_for_user = yield self.store.get_rooms_for_user(user_id) + if room_id not in rooms_for_user: + raise AuthError( + 403, + "You must be in the room to create an alias for it", + ) + if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): raise AuthError( 403, "This user is not permitted to create this alias", diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1b4d8c74a..eb525070c 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -21,7 +21,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase -from synapse.events.utils import serialize_event from synapse.types import UserID from synapse.util.logutils import log_function from synapse.visibility import filter_events_for_client @@ -50,6 +49,7 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @log_function @@ -120,9 +120,12 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() - chunks = [ - serialize_event(e, time_now, as_client_event) for e in events - ] + chunks = yield self._event_serializer.serialize_events( + events, time_now, as_client_event=as_client_event, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_aggregations=False, + ) chunk = { "chunk": chunks, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 068477888..cf4fad7de 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1916,6 +1916,11 @@ class FederationHandler(BaseHandler): event.room_id, latest_event_ids=extrem_ids, ) + logger.debug( + "Doing soft-fail check for %s: state %s", + event.event_id, current_state_ids, + ) + # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ @@ -1932,7 +1937,7 @@ class FederationHandler(BaseHandler): self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: logger.warn( - "Failed current state auth resolution for %r because %s", + "Soft-failing %r because %s", event, e, ) event.internal_metadata.soft_failed = True @@ -2008,15 +2013,44 @@ class FederationHandler(BaseHandler): Args: origin (str): - event (synapse.events.FrozenEvent): + event (synapse.events.EventBase): context (synapse.events.snapshot.EventContext): - auth_events (dict[(str, str)->str]): + auth_events (dict[(str, str)->synapse.events.EventBase]): + Map from (event_type, state_key) to event + + What we expect the event's auth_events to be, based on the event's + position in the dag. I think? maybe?? + + Also NB that this function adds entries to it. + Returns: + defer.Deferred[None] + """ + room_version = yield self.store.get_room_version(event.room_id) + + yield self._update_auth_events_and_context_for_auth( + origin, event, context, auth_events + ) + try: + self.auth.check(room_version, event, auth_events=auth_events) + except AuthError as e: + logger.warn("Failed auth resolution for %r because %s", event, e) + raise e + + @defer.inlineCallbacks + def _update_auth_events_and_context_for_auth( + self, origin, event, context, auth_events + ): + """Helper for do_auth. See there for docs. + + Args: + origin (str): + event (synapse.events.EventBase): + context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: defer.Deferred[None] """ - # Check if we have all the auth events. - current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(event.auth_event_ids()) if event.is_state(): @@ -2024,11 +2058,21 @@ class FederationHandler(BaseHandler): else: event_key = None - if event_auth_events - current_state: + # if the event's auth_events refers to events which are not in our + # calculated auth_events, we need to fetch those events from somewhere. + # + # we start by fetching them from the store, and then try calling /event_auth/. + missing_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if missing_auth: # TODO: can we use store.have_seen_events here instead? have_events = yield self.store.get_seen_events_with_rejections( - event_auth_events - current_state + missing_auth ) + logger.debug("Got events %s from store", have_events) + missing_auth.difference_update(have_events.keys()) else: have_events = {} @@ -2037,13 +2081,12 @@ class FederationHandler(BaseHandler): for e in auth_events.values() }) - seen_events = set(have_events.keys()) - - missing_auth = event_auth_events - seen_events - current_state - if missing_auth: - logger.info("Missing auth: %s", missing_auth) # If we don't have all the auth events, we need to get them. + logger.info( + "auth_events contains unknown events: %s", + missing_auth, + ) try: remote_auth_chain = yield self.federation_client.get_event_auth( origin, event.room_id, event.event_id @@ -2084,145 +2127,168 @@ class FederationHandler(BaseHandler): have_events = yield self.store.get_seen_events_with_rejections( event.auth_event_ids() ) - seen_events = set(have_events.keys()) except Exception: # FIXME: logger.exception("Failed to get auth chain") + if event.internal_metadata.is_outlier(): + logger.info("Skipping auth_event fetch for outlier") + return + # FIXME: Assumes we have and stored all the state for all the # prev_events - current_state = set(e.event_id for e in auth_events.values()) - different_auth = event_auth_events - current_state + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if not different_auth: + return + + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) room_version = yield self.store.get_room_version(event.room_id) - if different_auth and not event.internal_metadata.is_outlier(): - # Do auth conflict res. - logger.info("Different auth: %s", different_auth) - - different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) - ).addErrback(unwrapFirstError) - - if different_events: - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) - - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event + different_events = yield logcontext.make_deferred_yieldable( + defer.gatherResults([ + logcontext.run_in_background( + self.store.get_event, + d, + allow_none=True, + allow_rejected=False, ) + for d in different_auth + if d in have_events and not have_events[d] + ], consumeErrors=True) + ).addErrback(unwrapFirstError) - auth_events.update(new_state) + if different_events: + local_view = dict(auth_events) + remote_view = dict(auth_events) + remote_view.update({ + (d.type, d.state_key): d for d in different_events if d + }) - current_state = set(e.event_id for e in auth_events.values()) - different_auth = event_auth_events - current_state + new_state = yield self.state_handler.resolve_events( + room_version, + [list(local_view.values()), list(remote_view.values())], + event + ) - yield self._update_context_for_auth_events( - event, context, auth_events, event_key, - ) + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) - if different_auth and not event.internal_metadata.is_outlier(): - logger.info("Different auth after resolution: %s", different_auth) + auth_events.update(new_state) - # Only do auth resolution if we have something new to say. - # We can't rove an auth failure. - do_resolution = False + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) - provable = [ - RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR, - ] + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) - for e_id in different_auth: - if e_id in have_events: - if have_events[e_id] in provable: - do_resolution = True - break + if not different_auth: + # we're done + return - if do_resolution: - prev_state_ids = yield context.get_prev_state_ids(self.store) - # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + logger.info( + "auth_events still refers to events which are not in the calculated auth " + "chain after state resolution: %s", + different_auth, + ) - try: - # 2. Get remote difference. - result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, - ) + # Only do auth resolution if we have something new to say. + # We can't prove an auth failure. + do_resolution = False - seen_remotes = yield self.store.have_seen_events( - [e.event_id for e in result["auth_chain"]] - ) + for e_id in different_auth: + if e_id in have_events: + if have_events[e_id] == RejectedReason.NOT_ANCESTOR: + do_resolution = True + break - # 3. Process any remote auth chain events we haven't seen. - for ev in result["auth_chain"]: - if ev.event_id in seen_remotes: - continue + if not do_resolution: + logger.info( + "Skipping auth resolution due to lack of provable rejection reasons" + ) + return - if ev.event_id == event.event_id: - continue + logger.info("Doing auth resolution") - try: - auth_ids = ev.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create - } - ev.internal_metadata.outlier = True + prev_state_ids = yield context.get_prev_state_ids(self.store) - logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id - ) - - yield self._handle_new_event( - origin, ev, auth_events=auth - ) - - if ev.event_id in event_auth_events: - auth_events[(ev.type, ev.state_key)] = ev - except AuthError: - pass - - except Exception: - # FIXME: - logger.exception("Failed to query auth chain") - - # 4. Look at rejects and their proofs. - # TODO. - - yield self._update_context_for_auth_events( - event, context, auth_events, event_key, - ) + # 1. Get what we think is the auth chain. + auth_ids = yield self.auth.compute_auth_events( + event, prev_state_ids + ) + local_auth_chain = yield self.store.get_auth_chain( + auth_ids, include_given=True + ) try: - self.auth.check(room_version, event, auth_events=auth_events) - except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + # 2. Get remote difference. + result = yield self.federation_client.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) + + seen_remotes = yield self.store.have_seen_events( + [e.event_id for e in result["auth_chain"]] + ) + + # 3. Process any remote auth chain events we haven't seen. + for ev in result["auth_chain"]: + if ev.event_id in seen_remotes: + continue + + if ev.event_id == event.event_id: + continue + + try: + auth_ids = ev.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in result["auth_chain"] + if e.event_id in auth_ids + or event.type == EventTypes.Create + } + ev.internal_metadata.outlier = True + + logger.debug( + "do_auth %s different_auth: %s", + event.event_id, e.event_id + ) + + yield self._handle_new_event( + origin, ev, auth_events=auth + ) + + if ev.event_id in event_auth_events: + auth_events[(ev.type, ev.state_key)] = ev + except AuthError: + pass + + except Exception: + # FIXME: + logger.exception("Failed to query auth chain") + + # 4. Look at rejects and their proofs. + # TODO. + + yield self._update_context_for_auth_events( + event, context, auth_events, event_key, + ) @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 7dfae78db..aaee5db0b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -19,7 +19,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig @@ -43,6 +42,7 @@ class InitialSyncHandler(BaseHandler): self.clock = hs.get_clock() self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + self._event_serializer = hs.get_event_client_serializer() def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True, include_archived=False): @@ -138,7 +138,9 @@ class InitialSyncHandler(BaseHandler): d["inviter"] = event.sender invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) + d["invite"] = yield self._event_serializer.serialize_event( + invite_event, time_now, as_client_event, + ) rooms_ret.append(d) @@ -185,18 +187,21 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now=time_now, + as_client_event=as_client_event, + ) + ), "start": start_token.to_string(), "end": end_token.to_string(), } - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() - ] + d["state"] = yield self._event_serializer.serialize_events( + current_state.values(), + time_now=time_now, + as_client_event=as_client_event + ) account_data_events = [] tags = tags_by_room.get(event.room_id) @@ -337,11 +342,15 @@ class InitialSyncHandler(BaseHandler): "membership": membership, "room_id": room_id, "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], + "chunk": (yield self._event_serializer.serialize_events( + messages, time_now, + )), "start": start_token.to_string(), "end": end_token.to_string(), }, - "state": [serialize_event(s, time_now) for s in room_state.values()], + "state": (yield self._event_serializer.serialize_events( + room_state.values(), time_now, + )), "presence": [], "receipts": [], }) @@ -355,10 +364,9 @@ class InitialSyncHandler(BaseHandler): # TODO: These concurrently time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] + state = yield self._event_serializer.serialize_events( + current_state.values(), time_now, + ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -425,7 +433,9 @@ class InitialSyncHandler(BaseHandler): ret = { "room_id": room_id, "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], + "chunk": (yield self._event_serializer.serialize_events( + messages, time_now, + )), "start": start_token.to_string(), "end": end_token.to_string(), }, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 224d34ef3..0b02469ce 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.errors import ( AuthError, Codes, @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder -from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter @@ -57,6 +56,7 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, @@ -164,9 +164,13 @@ class MessageHandler(object): room_state = room_state[membership_event_id] now = self.clock.time_msec() - defer.returnValue( - [serialize_event(c, now) for c in room_state.values()] + events = yield self._event_serializer.serialize_events( + room_state.values(), now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_aggregations=False, ) + defer.returnValue(events) @defer.inlineCallbacks def get_joined_members(self, requester, room_id): @@ -228,6 +232,7 @@ class EventCreationHandler(object): self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config + self.require_membership_for_aliases = hs.config.require_membership_for_aliases self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) @@ -336,6 +341,35 @@ class EventCreationHandler(object): prev_events_and_hashes=prev_events_and_hashes, ) + # In an ideal world we wouldn't need the second part of this condition. However, + # this behaviour isn't spec'd yet, meaning we should be able to deactivate this + # behaviour. Another reason is that this code is also evaluated each time a new + # m.room.aliases event is created, which includes hitting a /directory route. + # Therefore not including this condition here would render the similar one in + # synapse.handlers.directory pointless. + if builder.type == EventTypes.Aliases and self.require_membership_for_aliases: + # Ideally we'd do the membership check in event_auth.check(), which + # describes a spec'd algorithm for authenticating events received over + # federation as well as those created locally. As of room v3, aliases events + # can be created by users that are not in the room, therefore we have to + # tolerate them in event_auth.check(). + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event or prev_event.membership != Membership.JOIN: + logger.warning( + ("Attempt to send `m.room.aliases` in room %s by user %s but" + " membership is %s"), + event.room_id, + event.sender, + prev_event.membership if prev_event else None, + ) + + raise AuthError( + 403, + "You must be in the room to create an alias for it", + ) + self.validator.validate_new(event) defer.returnValue((event, context)) @@ -570,6 +604,20 @@ class EventCreationHandler(object): self.validator.validate_new(event) + # If this event is an annotation then we check that that the sender + # can't annotate the same way twice (e.g. stops users from liking an + # event multiple times). + relation = event.content.get("m.relates_to", {}) + if relation.get("rel_type") == RelationTypes.ANNOTATION: + relates_to = relation["event_id"] + aggregation_key = relation["key"] + + already_exists = yield self.store.has_user_annotated_event( + relates_to, event.type, aggregation_key, event.sender, + ) + if already_exists: + raise SynapseError(400, "Can't send same reaction twice") + logger.debug( "Created event %s", event.event_id, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e4fdae926..8f811e24f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -20,7 +20,6 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.events.utils import serialize_event from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -78,6 +77,7 @@ class PaginationHandler(object): self._purges_in_progress_by_room = set() # map from purge id to PurgeStatus self._purges_by_id = {} + self._event_serializer = hs.get_event_client_serializer() def start_purge_history(self, room_id, token, delete_local_events=False): @@ -278,18 +278,22 @@ class PaginationHandler(object): time_now = self.clock.time_msec() chunk = { - "chunk": [ - serialize_event(e, time_now, as_client_event) - for e in events - ], + "chunk": ( + yield self._event_serializer.serialize_events( + events, time_now, + as_client_event=as_client_event, + ) + ), "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), } if state: - chunk["state"] = [ - serialize_event(e, time_now, as_client_event) - for e in state - ] + chunk["state"] = ( + yield self._event_serializer.serialize_events( + state, time_now, + as_client_event=as_client_event, + ) + ) defer.returnValue(chunk) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index bd1285b15..6209858bb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -182,17 +182,27 @@ class PresenceHandler(object): # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. + def run_timeout_handler(): + return run_as_background_process( + "handle_presence_timeouts", self._handle_timeouts + ) + self.clock.call_later( 30, self.clock.looping_call, - self._handle_timeouts, + run_timeout_handler, 5000, ) + def run_persister(): + return run_as_background_process( + "persist_presence_changes", self._persist_unpersisted_changes + ) + self.clock.call_later( 60, self.clock.looping_call, - self._persist_unpersisted_changes, + run_persister, 60 * 1000, ) @@ -229,6 +239,7 @@ class PresenceHandler(object): ) if self.unpersisted_users_changes: + yield self.store.update_presence([ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -240,30 +251,18 @@ class PresenceHandler(object): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ - logger.info( - "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes", - len(self.unpersisted_users_changes) - ) - unpersisted = self.unpersisted_users_changes self.unpersisted_users_changes = set() if unpersisted: + logger.info( + "Persisting %d upersisted presence updates", len(unpersisted) + ) yield self.store.update_presence([ self.user_to_current_state[user_id] for user_id in unpersisted ]) - logger.info("Finished _persist_unpersisted_changes") - - @defer.inlineCallbacks - def _update_states_and_catch_exception(self, new_states): - try: - res = yield self._update_states(new_states) - defer.returnValue(res) - except Exception: - logger.exception("Error updating presence") - @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes @@ -338,45 +337,41 @@ class PresenceHandler(object): logger.info("Handling presence timeouts") now = self.clock.time_msec() - try: - with Measure(self.clock, "presence_handle_timeouts"): - # Fetch the list of users that *may* have timed out. Things may have - # changed since the timeout was set, so we won't necessarily have to - # take any action. - users_to_check = set(self.wheel_timer.fetch(now)) + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = set(self.wheel_timer.fetch(now)) - # Check whether the lists of syncing processes from an external - # process have expired. - expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_updated_ms.items() - if now - last_update > EXTERNAL_PROCESS_EXPIRY - ] - for process_id in expired_process_ids: - users_to_check.update( - self.external_process_last_updated_ms.pop(process_id, ()) - ) - self.external_process_last_update.pop(process_id) + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_updated_ms.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_last_updated_ms.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) - states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) - for user_id in users_to_check - ] + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in users_to_check + ] - timers_fired_counter.inc(len(states)) + timers_fired_counter.inc(len(states)) - changes = handle_timeouts( - states, - is_mine_fn=self.is_mine_id, - syncing_user_ids=self.get_currently_syncing_users(), - now=now, - ) + changes = handle_timeouts( + states, + is_mine_fn=self.is_mine_id, + syncing_user_ids=self.get_currently_syncing_users(), + now=now, + ) - run_in_background(self._update_states_and_catch_exception, changes) - except Exception: - logger.exception("Exception in _handle_timeouts loop") + return self._update_states(changes) @defer.inlineCallbacks def bump_presence_active_time(self, user): @@ -828,6 +823,11 @@ class PresenceHandler(object): if typ != EventTypes.Member: continue + if event_id is None: + # state has been deleted, so this is not a join. We only care about + # joins. + continue + event = yield self.store.get_event(event_id) if event.content.get("membership") != Membership.JOIN: # We only care about joins diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index a65c98ff5..a5fc6c5db 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -31,6 +31,9 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +MAX_DISPLAYNAME_LEN = 100 +MAX_AVATAR_URL_LEN = 1000 + class BaseProfileHandler(BaseHandler): """Handles fetching and updating user profile information. @@ -53,6 +56,7 @@ class BaseProfileHandler(BaseHandler): @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) + if self.hs.is_mine(target_user): try: displayname = yield self.store.get_profile_displayname( @@ -161,6 +165,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: + raise SynapseError( + 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ), + ) + if new_displayname == '': new_displayname = None @@ -216,6 +225,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: + raise SynapseError( + 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ), + ) + yield self.store.set_profile_avatar_url( target_user.localpart, new_avatar_url ) @@ -283,6 +297,48 @@ class BaseProfileHandler(BaseHandler): room_id, str(e) ) + @defer.inlineCallbacks + def check_profile_query_allowed(self, target_user, requester=None): + """Checks whether a profile query is allowed. If the + 'require_auth_for_profile_requests' config flag is set to True and a + 'requester' is provided, the query is only allowed if the two users + share a room. + + Args: + target_user (UserID): The owner of the queried profile. + requester (None|UserID): The user querying for the profile. + + Raises: + SynapseError(403): The two users share no room, or ne user couldn't + be found to be in any room the server is in, and therefore the query + is denied. + """ + # Implementation of MSC1301: don't allow looking up profiles if the + # requester isn't in the same room as the target. We expect requester to + # be None when this function is called outside of a profile query, e.g. + # when building a membership event. In this case, we must allow the + # lookup. + if not self.hs.config.require_auth_for_profile_requests or not requester: + return + + try: + requester_rooms = yield self.store.get_rooms_for_user( + requester.to_string() + ) + target_user_rooms = yield self.store.get_rooms_for_user( + target_user.to_string(), + ) + + # Check if the room lists have no elements in common. + if requester_rooms.isdisjoint(target_user_rooms): + raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) + except StoreError as e: + if e.code == 404: + # This likely means that one of the users doesn't exist, + # so we act as if we couldn't find the profile. + raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) + raise + class MasterProfileHandler(BaseProfileHandler): PROFILE_UPDATE_MS = 60 * 1000 diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a51d11a25..9a388ea01 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,7 @@ import logging from twisted.internet import defer from synapse import types -from synapse.api.constants import LoginType +from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.errors import ( AuthError, Codes, @@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id) + if len(user_id) > MAX_USERID_LENGTH: + raise SynapseError( + 400, + "User ID may not be longer than %s characters" % ( + MAX_USERID_LENGTH, + ), + Codes.INVALID_USERNAME + ) + users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: @@ -522,6 +531,8 @@ class RegistrationHandler(BaseHandler): A tuple of (user_id, access_token). Raises: RegistrationError if there was a problem registering. + + NB this is only used in tests. TODO: move it to the test package! """ if localpart is None: raise SynapseError(400, "Request must include user id") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 17628e268..4a17911a8 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -27,7 +27,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError -from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -70,6 +70,7 @@ class RoomCreationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self.config = hs.config # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") @@ -402,7 +403,7 @@ class RoomCreationHandler(BaseHandler): yield directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, servers=(self.hs.hostname, ), - send_event=False, + send_event=False, check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: @@ -475,7 +476,11 @@ class RoomCreationHandler(BaseHandler): if ratelimit: yield self.ratelimit(requester) - room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier) + room_version = config.get( + "room_version", + self.config.default_room_version.identifier, + ) + if not isinstance(room_version, string_types): raise SynapseError( 400, @@ -538,6 +543,7 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, servers=[self.hs.hostname], send_event=False, + check_membership=False, ) preset_config = config.get( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 024d6db27..93ac986c8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +34,8 @@ from synapse.types import RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room +from ._base import BaseHandler + logger = logging.getLogger(__name__) id_server_scheme = "https://" @@ -71,6 +74,12 @@ class RoomMemberHandler(object): self.spam_checker = hs.get_spam_checker() self._server_notices_mxid = self.config.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup + self.allow_per_room_profiles = self.config.allow_per_room_profiles + + # This is only used to get at ratelimit function, and + # maybe_kick_guest_users. It's fine there are multiple of these as + # it doesn't store state. + self.base_handler = BaseHandler(hs) @abc.abstractmethod def _remote_join(self, requester, remote_room_hosts, room_id, user, content): @@ -350,6 +359,13 @@ class RoomMemberHandler(object): # later on. content = dict(content) + if not self.allow_per_room_profiles: + # Strip profile data, knowing that new profile data will be added to the + # event's content in event_creation_handler.create_event() using the target's + # global profile. + content.pop("displayname", None) + content.pop("avatar_url", None) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -703,6 +719,10 @@ class RoomMemberHandler(object): Codes.FORBIDDEN, ) + # We need to rate limit *before* we send out any 3PID invites, so we + # can't just rely on the standard ratelimiting of events. + yield self.base_handler.ratelimit(requester) + invitee = yield self._lookup_3pid( id_server, medium, address ) @@ -924,7 +944,7 @@ class RoomMemberHandler(object): } if self.config.invite_3pid_guest: - guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest( + guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest( requester=requester, medium=medium, address=address, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 49c439313..9bba74d6c 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,7 +23,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events.utils import serialize_event from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -36,6 +35,7 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -401,14 +401,16 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = [ - serialize_event(e, time_now) - for e in context["events_before"] - ] - context["events_after"] = [ - serialize_event(e, time_now) - for e in context["events_after"] - ] + context["events_before"] = ( + yield self._event_serializer.serialize_events( + context["events_before"], time_now, + ) + ) + context["events_after"] = ( + yield self._event_serializer.serialize_events( + context["events_after"], time_now, + ) + ) state_results = {} if include_state: @@ -422,14 +424,13 @@ class SearchHandler(BaseHandler): # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong - results = [ - { + results = [] + for e in allowed_events: + results.append({ "rank": rank_map[e.event_id], - "result": serialize_event(e, time_now), + "result": (yield self._event_serializer.serialize_event(e, time_now)), "context": contexts.get(e.event_id, {}), - } - for e in allowed_events - ] + }) rooms_cat_res = { "results": results, @@ -438,10 +439,13 @@ class SearchHandler(BaseHandler): } if state_results: - rooms_cat_res["state"] = { - room_id: [serialize_event(e, time_now) for e in state] - for room_id, state in state_results.items() - } + s = {} + for room_id, state in state_results.items(): + s[room_id] = yield self._event_serializer.serialize_events( + state, time_now, + ) + + rooms_cat_res["state"] = s if room_groups and "room_id" in group_keys: rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py new file mode 100644 index 000000000..0e92b405b --- /dev/null +++ b/synapse/handlers/stats.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.handlers.state_deltas import StateDeltasHandler +from synapse.metrics import event_processing_positions +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +class StatsHandler(StateDeltasHandler): + """Handles keeping the *_stats tables updated with a simple time-series of + information about the users, rooms and media on the server, such that admins + have some idea of who is consuming their resources. + + Heavily derived from UserDirectoryHandler + """ + + def __init__(self, hs): + super(StatsHandler, self).__init__(hs) + self.hs = hs + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.stats_bucket_size = hs.config.stats_bucket_size + + # The current position in the current_state_delta stream + self.pos = None + + # Guard to ensure we only process deltas one at a time + self._is_processing = False + + if hs.config.stats_enabled: + self.notifier.add_replication_callback(self.notify_new_event) + + # We kick this off so that we don't have to wait for a change before + # we start populating stats + self.clock.call_later(0, self.notify_new_event) + + def notify_new_event(self): + """Called when there may be more deltas to process + """ + if not self.hs.config.stats_enabled: + return + + if self._is_processing: + return + + @defer.inlineCallbacks + def process(): + try: + yield self._unsafe_process() + finally: + self._is_processing = False + + self._is_processing = True + run_as_background_process("stats.notify_new_event", process) + + @defer.inlineCallbacks + def _unsafe_process(self): + # If self.pos is None then means we haven't fetched it from DB + if self.pos is None: + self.pos = yield self.store.get_stats_stream_pos() + + # If still None then the initial background update hasn't happened yet + if self.pos is None: + defer.returnValue(None) + + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "stats_delta"): + deltas = yield self.store.get_current_state_deltas(self.pos) + if not deltas: + return + + logger.info("Handling %d state deltas", len(deltas)) + yield self._handle_deltas(deltas) + + self.pos = deltas[-1]["stream_id"] + yield self.store.update_stats_stream_pos(self.pos) + + event_processing_positions.labels("stats").set(self.pos) + + @defer.inlineCallbacks + def _handle_deltas(self, deltas): + """ + Called with the state deltas to process + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + stream_id = delta["stream_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + token = yield self.store.get_earliest_token_for_room_stats(room_id) + + # If the earliest token to begin from is larger than our current + # stream ID, skip processing this delta. + if token is not None and token >= stream_id: + logger.debug( + "Ignoring: %s as earlier than this room's initial ingestion event", + event_id, + ) + continue + + if event_id is None and prev_event_id is None: + # Errr... + continue + + event_content = {} + + if event_id is not None: + event_content = (yield self.store.get_event(event_id)).content or {} + + # quantise time to the nearest bucket + now = yield self.store.get_received_ts(event_id) + now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size + + if typ == EventTypes.Member: + # we could use _get_key_change here but it's a bit inefficient + # given we're not testing for a specific result; might as well + # just grab the prev_membership and membership strings and + # compare them. + prev_event_content = {} + if prev_event_id is not None: + prev_event_content = ( + yield self.store.get_event(prev_event_id) + ).content + + membership = event_content.get("membership", Membership.LEAVE) + prev_membership = prev_event_content.get("membership", Membership.LEAVE) + + if prev_membership == membership: + continue + + if prev_membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", -1 + ) + elif prev_membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", -1 + ) + elif prev_membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", -1 + ) + elif prev_membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", -1 + ) + else: + err = "%s is not a valid prev_membership" % (repr(prev_membership),) + logger.error(err) + raise ValueError(err) + + if membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", +1 + ) + elif membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", +1 + ) + elif membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", +1 + ) + elif membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", +1 + ) + else: + err = "%s is not a valid membership" % (repr(membership),) + logger.error(err) + raise ValueError(err) + + user_id = state_key + if self.is_mine_id(user_id): + # update user_stats as it's one of our users + public = yield self._is_public_room(room_id) + + if membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + -1, + ) + elif membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + +1, + ) + + elif typ == EventTypes.Create: + # Newly created room. Add it with all blank portions. + yield self.store.update_room_state( + room_id, + { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + }, + ) + + elif typ == EventTypes.JoinRules: + yield self.store.update_room_state( + room_id, {"join_rules": event_content.get("join_rule")} + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "join_rule", JoinRules.PUBLIC + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.RoomHistoryVisibility: + yield self.store.update_room_state( + room_id, + {"history_visibility": event_content.get("history_visibility")}, + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "history_visibility", "world_readable" + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.Encryption: + yield self.store.update_room_state( + room_id, {"encryption": event_content.get("algorithm")} + ) + elif typ == EventTypes.Name: + yield self.store.update_room_state( + room_id, {"name": event_content.get("name")} + ) + elif typ == EventTypes.Topic: + yield self.store.update_room_state( + room_id, {"topic": event_content.get("topic")} + ) + elif typ == EventTypes.RoomAvatar: + yield self.store.update_room_state( + room_id, {"avatar": event_content.get("url")} + ) + elif typ == EventTypes.CanonicalAlias: + yield self.store.update_room_state( + room_id, {"canonical_alias": event_content.get("alias")} + ) + + @defer.inlineCallbacks + def update_public_room_stats(self, ts, room_id, is_public): + """ + Increment/decrement a user's number of public rooms when a room they are + in changes to/from public visibility. + + Args: + ts (int): Timestamp in seconds + room_id (str) + is_public (bool) + """ + # For now, blindly iterate over all local users in the room so that + # we can handle the whole problem of copying buckets over as needed + user_ids = yield self.store.get_users_in_room(room_id) + + for user_id in user_ids: + if self.hs.is_mine(UserID.from_string(user_id)): + yield self.store.update_stats_delta( + ts, "user", user_id, "public_rooms", +1 if is_public else -1 + ) + yield self.store.update_stats_delta( + ts, "user", user_id, "private_rooms", -1 if is_public else +1 + ) + + @defer.inlineCallbacks + def _is_public_room(self, room_id): + join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules) + history_visibility = yield self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility + ) + + if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or ( + ( + history_visibility + and history_visibility.content.get("history_visibility") + == "world_readable" + ) + ): + defer.returnValue(True) + else: + defer.returnValue(False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7cf757f65..72997d6d0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -934,7 +934,7 @@ class SyncHandler(object): res = yield self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_users, _, _ = res + newly_joined_rooms, newly_joined_or_invited_users, _, _ = res _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( @@ -943,7 +943,7 @@ class SyncHandler(object): ) if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_users + sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) yield self._generate_sync_entry_for_to_device(sync_result_builder) @@ -951,7 +951,7 @@ class SyncHandler(object): device_lists = yield self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, - newly_joined_users=newly_joined_users, + newly_joined_or_invited_users=newly_joined_or_invited_users, newly_left_rooms=newly_left_rooms, newly_left_users=newly_left_users, ) @@ -1036,7 +1036,8 @@ class SyncHandler(object): @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks def _generate_sync_entry_for_device_list(self, sync_result_builder, - newly_joined_rooms, newly_joined_users, + newly_joined_rooms, + newly_joined_or_invited_users, newly_left_rooms, newly_left_users): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1050,7 +1051,7 @@ class SyncHandler(object): # share a room with? for room_id in newly_joined_rooms: joined_users = yield self.state.get_current_users_in_room(room_id) - newly_joined_users.update(joined_users) + newly_joined_or_invited_users.update(joined_users) for room_id in newly_left_rooms: left_users = yield self.state.get_current_users_in_room(room_id) @@ -1058,7 +1059,7 @@ class SyncHandler(object): # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. - changed.update(newly_joined_users) + changed.update(newly_joined_or_invited_users) if not changed and not newly_left_users: defer.returnValue(DeviceLists( @@ -1176,7 +1177,7 @@ class SyncHandler(object): @defer.inlineCallbacks def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, - newly_joined_users): + newly_joined_or_invited_users): """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1184,8 +1185,9 @@ class SyncHandler(object): sync_result_builder(SyncResultBuilder) newly_joined_rooms(list): List of rooms that the user has joined since the last sync (or empty if an initial sync) - newly_joined_users(list): List of users that have joined rooms - since the last sync (or empty if an initial sync) + newly_joined_or_invited_users(list): List of users that have joined + or been invited to rooms since the last sync (or empty if an initial + sync) """ now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1211,7 +1213,7 @@ class SyncHandler(object): "presence_key", presence_key ) - extra_users_ids = set(newly_joined_users) + extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: users = yield self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) @@ -1243,7 +1245,8 @@ class SyncHandler(object): Returns: Deferred(tuple): Returns a 4-tuple of - `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)` + `(newly_joined_rooms, newly_joined_or_invited_users, + newly_left_rooms, newly_left_users)` """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( @@ -1314,8 +1317,8 @@ class SyncHandler(object): sync_result_builder.invited.extend(invited) - # Now we want to get any newly joined users - newly_joined_users = set() + # Now we want to get any newly joined or invited users + newly_joined_or_invited_users = set() newly_left_users = set() if since_token: for joined_sync in sync_result_builder.joined: @@ -1324,19 +1327,22 @@ class SyncHandler(object): ) for event in it: if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - newly_joined_users.add(event.state_key) + if ( + event.membership == Membership.JOIN or + event.membership == Membership.INVITE + ): + newly_joined_or_invited_users.add(event.state_key) else: prev_content = event.unsigned.get("prev_content", {}) prev_membership = prev_content.get("membership", None) if prev_membership == Membership.JOIN: newly_left_users.add(event.state_key) - newly_left_users -= newly_joined_users + newly_left_users -= newly_joined_or_invited_users defer.returnValue(( newly_joined_rooms, - newly_joined_users, + newly_joined_or_invited_users, newly_left_rooms, newly_left_users, )) @@ -1381,7 +1387,7 @@ class SyncHandler(object): where: room_entries is a list [RoomSyncResultBuilder] invited_rooms is a list [InvitedSyncResult] - newly_joined rooms is a list[str] of room ids + newly_joined_rooms is a list[str] of room ids newly_left_rooms is a list[str] of room ids """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1422,7 +1428,7 @@ class SyncHandler(object): if room_id in sync_result_builder.joined_room_ids and non_joins: # Always include if the user (re)joined the room, especially # important so that device list changes are calculated correctly. - # If there are non join member events, but we are still in the room, + # If there are non-join member events, but we are still in the room, # then the user must have left and joined newly_joined_rooms.append(room_id) diff --git a/synapse/http/client.py b/synapse/http/client.py index ad454f496..77fe68818 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -90,9 +90,32 @@ class IPBlacklistingResolver(object): def resolveHostName(self, recv, hostname, portNumber=0): r = recv() - d = defer.Deferred() addresses = [] + def _callback(): + r.resolutionBegan(None) + + has_bad_ip = False + for i in addresses: + ip_address = IPAddress(i.host) + + if check_against_blacklist( + ip_address, self._ip_whitelist, self._ip_blacklist + ): + logger.info( + "Dropped %s from DNS resolution to %s due to blacklist" % + (ip_address, hostname) + ) + has_bad_ip = True + + # if we have a blacklisted IP, we'd like to raise an error to block the + # request, but all we can really do from here is claim that there were no + # valid results. + if not has_bad_ip: + for i in addresses: + r.addressResolved(i) + r.resolutionComplete() + @provider(IResolutionReceiver) class EndpointReceiver(object): @staticmethod @@ -101,34 +124,16 @@ class IPBlacklistingResolver(object): @staticmethod def addressResolved(address): - ip_address = IPAddress(address.host) - - if check_against_blacklist( - ip_address, self._ip_whitelist, self._ip_blacklist - ): - logger.info( - "Dropped %s from DNS resolution to %s" % (ip_address, hostname) - ) - raise SynapseError(403, "IP address blocked by IP blacklist entry") - addresses.append(address) @staticmethod def resolutionComplete(): - d.callback(addresses) + _callback() self._reactor.nameResolver.resolveHostName( EndpointReceiver, hostname, portNumber=portNumber ) - def _callback(addrs): - r.resolutionBegan(None) - for i in addrs: - r.addressResolved(i) - r.resolutionComplete() - - d.addCallback(_callback) - return r @@ -160,7 +165,8 @@ class BlacklistingAgentWrapper(Agent): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Blocking access to %s because of blacklist" % (ip_address,) + "Blocking access to %s due to blacklist" % + (ip_address,) ) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) @@ -258,9 +264,6 @@ class SimpleHttpClient(object): uri (str): URI to query. data (bytes): Data to send in the request body, if applicable. headers (t.w.http_headers.Headers): Request headers. - - Raises: - SynapseError: If the IP is blacklisted. """ # A small wrapper around self.agent.request() so we can easily attach # counters to it diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1334c630c..b4cbe97b4 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -149,7 +149,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + res.tls_server_name.decode("ascii"), ) # make sure that the Host header is set correctly diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ff63d0b2a..663ea72a7 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -27,9 +27,11 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json +from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers @@ -44,6 +46,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http import QuieterFileBodyProducer +from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable @@ -172,19 +175,44 @@ class MatrixFederationHttpClient(object): self.hs = hs self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname - reactor = hs.get_reactor() + + real_reactor = hs.get_reactor() + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + nameResolver = IPBlacklistingResolver( + real_reactor, None, hs.config.federation_ip_range_blacklist, + ) + + @implementer(IReactorPluggableNameResolver) + class Reactor(object): + def __getattr__(_self, attr): + if attr == "nameResolver": + return nameResolver + else: + return getattr(real_reactor, attr) + + self.reactor = Reactor() self.agent = MatrixFederationAgent( - hs.get_reactor(), + self.reactor, tls_client_options_factory, ) + + # Use a BlacklistingAgentWrapper to prevent circumventing the IP + # blacklist via IP literals in server names + self.agent = BlacklistingAgentWrapper( + self.agent, self.reactor, + ip_blacklist=hs.config.federation_ip_range_blacklist, + ) + self.clock = hs.get_clock() self._store = hs.get_datastore() self.version_string_bytes = hs.version_string.encode('ascii') self.default_timeout = 60 def schedule(x): - reactor.callLater(_EPSILON, x) + self.reactor.callLater(_EPSILON, x) self._cooperator = Cooperator(scheduler=schedule) @@ -257,7 +285,24 @@ class MatrixFederationHttpClient(object): request (MatrixFederationRequest): details of request to be sent timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server). 60s by default. + (including connecting to the server), *for each attempt*. + 60s by default. + + long_retries (bool): whether to use the long retry algorithm. + + The regular retry algorithm makes 4 attempts, with intervals + [0.5s, 1s, 2s]. + + The long retry algorithm makes 11 attempts, with intervals + [4s, 16s, 60s, 60s, ...] + + Both algorithms add -20%/+40% jitter to the retry intervals. + + Note that the above intervals are *in addition* to the time spent + waiting for the request to complete (up to `timeout` ms). + + NB: the long retry algorithm takes over 20 minutes to complete, with + a default timeout of 60s! ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -370,7 +415,7 @@ class MatrixFederationHttpClient(object): request_deferred = timeout_deferred( request_deferred, timeout=_sec_timeout, - reactor=self.hs.get_reactor(), + reactor=self.reactor, ) response = yield request_deferred @@ -397,7 +442,7 @@ class MatrixFederationHttpClient(object): d = timeout_deferred( d, timeout=_sec_timeout, - reactor=self.hs.get_reactor(), + reactor=self.reactor, ) try: @@ -538,10 +583,14 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to use as the request body. - long_retries (bool): A boolean that indicates whether we should - retry for a short or long time. - timeout(int): How long to try (in ms) the destination for before - giving up. None indicates no timeout. + + long_retries (bool): whether to use the long retry algorithm. See + docs on _send_request for details. + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server), *for each attempt*. + self._default_timeout (60s) by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as @@ -586,7 +635,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -599,15 +648,22 @@ class MatrixFederationHttpClient(object): Args: destination (str): The remote server to send the HTTP request to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. - long_retries (bool): A boolean that indicates whether we should - retry for a short or long time. - timeout(int): How long to try (in ms) the destination for before - giving up. None indicates no timeout. + + long_retries (bool): whether to use the long retry algorithm. See + docs on _send_request for details. + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server), *for each attempt*. + self._default_timeout (60s) by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + args (dict): query params Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The @@ -645,7 +701,7 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.hs.get_reactor(), _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response, ) defer.returnValue(body) @@ -658,14 +714,19 @@ class MatrixFederationHttpClient(object): Args: destination (str): The remote server to send the HTTP request to. + path (str): The HTTP path. + args (dict|None): A dictionary used to create query strings, defaults to None. - timeout (int): How long to try (in ms) the destination for before - giving up. None indicates no timeout and that the request will - be retried. + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server), *for each attempt*. + self._default_timeout (60s) by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. @@ -683,10 +744,6 @@ class MatrixFederationHttpClient(object): RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ - logger.debug("get_json args: %s", args) - - logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) - request = MatrixFederationRequest( method="GET", destination=destination, @@ -704,7 +761,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -718,12 +775,18 @@ class MatrixFederationHttpClient(object): destination (str): The remote server to send the HTTP request to. path (str): The HTTP path. - long_retries (bool): A boolean that indicates whether we should - retry for a short or long time. - timeout(int): How long to try (in ms) the destination for before - giving up. None indicates no timeout. + + long_retries (bool): whether to use the long retry algorithm. See + docs on _send_request for details. + + timeout (int|None): number of milliseconds to wait for the response headers + (including connecting to the server), *for each attempt*. + self._default_timeout (60s) by default. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + + args (dict): query params Returns: Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -753,7 +816,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -801,7 +864,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) + d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: logger.warn( diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 528125e73..197c65285 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -55,7 +55,7 @@ def parse_integer_from_args(args, name, default=None, required=False): return int(args[name][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: if required: message = "Missing integer query parameter %r" % (name,) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 8f0682c94..3523a4010 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -261,6 +261,23 @@ BASE_APPEND_OVERRIDE_RULES = [ 'value': True, } ] + }, + { + 'rule_id': 'global/override/.m.rule.tombstone', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.tombstone', + '_id': '_tombstone', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': True, + } + ] } ] diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 779f36dbe..f64baa4d5 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -16,7 +16,12 @@ import logging -from pkg_resources import DistributionNotFound, VersionConflict, get_distribution +from pkg_resources import ( + DistributionNotFound, + Requirement, + VersionConflict, + get_provider, +) logger = logging.getLogger(__name__) @@ -53,7 +58,7 @@ REQUIREMENTS = [ "pyasn1-modules>=0.0.7", "daemonize>=2.3.1", "bcrypt>=3.1.0", - "pillow>=3.1.2", + "pillow>=4.3.0", "sortedcontainers>=1.4.4", "psutil>=2.0.0", "pymacaroons>=0.13.0", @@ -83,7 +88,13 @@ CONDITIONAL_REQUIREMENTS = { # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. - "acme": ["txacme>=0.9.2"], + "acme": [ + "txacme>=0.9.2", + + # txacme depends on eliot. Eliot 1.8.0 is incompatible with + # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 + 'eliot<1.8.0;python_version<"3.5.3"', + ], "saml2": ["pysaml2>=4.5.0"], "systemd": ["systemd-python>=231"], @@ -117,10 +128,10 @@ class DependencyException(Exception): @property def dependencies(self): for i in self.args[0]: - yield '"' + i + '"' + yield "'" + i + "'" -def check_requirements(for_feature=None, _get_distribution=get_distribution): +def check_requirements(for_feature=None): deps_needed = [] errors = [] @@ -131,7 +142,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution): for dependency in reqs: try: - _get_distribution(dependency) + _check_requirement(dependency) except VersionConflict as e: deps_needed.append(dependency) errors.append( @@ -149,7 +160,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution): for dependency in OPTS: try: - _get_distribution(dependency) + _check_requirement(dependency) except VersionConflict as e: deps_needed.append(dependency) errors.append( @@ -167,6 +178,23 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution): raise DependencyException(deps_needed) +def _check_requirement(dependency_string): + """Parses a dependency string, and checks if the specified requirement is installed + + Raises: + VersionConflict if the requirement is installed, but with the the wrong version + DistributionNotFound if nothing is found to provide the requirement + """ + req = Requirement.parse(dependency_string) + + # first check if the markers specify that this requirement needs installing + if req.marker is not None and not req.marker.evaluate(): + # not required for this environment + return + + get_provider(req) + + if __name__ == "__main__": import sys diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b457c5563..a3952506c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.relations import RelationsWorkerStore from synapse.storage.roommember import RoomMemberWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.state import StateGroupWorkerStore @@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore, EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, + RelationsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): @@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore, for row in rows: self.invalidate_caches_for_event( -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, + row.redacts, row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, + data.redacts, data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: @@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore, raise Exception("Unknown events stream row type %s" % (row.type, )) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, backfilled): + etype, state_key, redacts, relates_to, + backfilled): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore, state_key, stream_ordering ) self.get_invited_rooms_for_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8971a6a22..b6ce7a7be 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", ( "type", # str "state_key", # str, optional "redacts", # str, optional + "relates_to", # str, optional )) PresenceStreamRow = namedtuple("PresenceStreamRow", ( "user_id", # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e0f6e2924..f1290d022 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -80,11 +80,12 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional + relates_to = attr.ib() # str, optional @attr.s(slots=True, frozen=True) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index a66885d34..e6110ad9b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -13,11 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import synapse.rest.admin from synapse.http.server import JsonResource from synapse.rest.client import versions from synapse.rest.client.v1 import ( - admin, directory, events, initial_sync, @@ -45,6 +44,7 @@ from synapse.rest.client.v2_alpha import ( read_marker, receipts, register, + relations, report_event, room_keys, room_upgrade_rest_servlet, @@ -58,8 +58,14 @@ from synapse.rest.client.v2_alpha import ( class ClientRestResource(JsonResource): - """A resource for version 1 of the matrix client API.""" + """Matrix Client API REST resource. + This gets mounted at various points under /_matrix/client, including: + * /_matrix/client/r0 + * /_matrix/client/api/v1 + * /_matrix/client/unstable + * etc + """ def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) @@ -82,7 +88,6 @@ class ClientRestResource(JsonResource): presence.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) - admin.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) logout.register_servlets(hs, client_resource) @@ -111,3 +116,9 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + relations.register_servlets(hs, client_resource) + + # moving to /_synapse/admin + synapse.rest.admin.register_servlets_for_client_rest_resource( + hs, client_resource + ) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/admin/__init__.py similarity index 78% rename from synapse/rest/client/v1/admin.py rename to synapse/rest/admin/__init__.py index 0a1e233b2..d6c4dcdb1 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import hashlib import hmac import logging import platform +import re from six import text_type from six.moves import http_client @@ -27,39 +28,56 @@ from twisted.internet import defer import synapse from synapse.api.constants import Membership, UserTypes from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.http.server import JsonResource from synapse.http.servlet import ( + RestServlet, assert_params_in_dict, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin +from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.types import UserID, create_requester from synapse.util.versionstring import get_version_string -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class UsersRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/users/(?P[^/]*)") +def historical_admin_path_patterns(path_regex): + """Returns the list of patterns for an admin endpoint, including historical ones + + This is a backwards-compatibility hack. Previously, the Admin API was exposed at + various paths under /_matrix/client. This function returns a list of patterns + matching those paths (as well as the new one), so that existing scripts which rely + on the endpoints being available there are not broken. + + Note that this should only be used for existing endpoints: new ones should just + register for the /_synapse/admin path. + """ + return list( + re.compile(prefix + path_regex) + for prefix in ( + "^/_synapse/admin/v1", + "^/_matrix/client/api/v1/admin", + "^/_matrix/client/unstable/admin", + "^/_matrix/client/r0/admin" + ) + ) + + +class UsersRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/users/(?P[^/]*)") def __init__(self, hs): - super(UsersRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() self.handlers = hs.get_handlers() @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") - - # To allow all users to get the users list - # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") @@ -69,37 +87,30 @@ class UsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class VersionServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/server_version") +class VersionServlet(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), ) - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") - - ret = { + def __init__(self, hs): + self.res = { 'server_version': get_version_string(synapse), 'python_version': platform.python_version(), } - defer.returnValue((200, ret)) + def on_GET(self, request): + return 200, self.res -class UserRegisterServlet(ClientV1RestServlet): +class UserRegisterServlet(RestServlet): """ Attributes: NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted nonces (dict[str, int]): The nonces that we will accept. A dict of nonce to the time it was generated, in int seconds. """ - PATTERNS = client_path_patterns("/admin/register") + PATTERNS = historical_admin_path_patterns("/register") NONCE_TIMEOUT = 60 def __init__(self, hs): - super(UserRegisterServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.reactor = hs.get_reactor() self.nonces = {} @@ -226,11 +237,12 @@ class UserRegisterServlet(ClientV1RestServlet): defer.returnValue((200, result)) -class WhoisRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)") +class WhoisRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/whois/(?P[^/]*)") def __init__(self, hs): - super(WhoisRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -238,10 +250,9 @@ class WhoisRestServlet(ClientV1RestServlet): target_user = UserID.from_string(user_id) requester = yield self.auth.get_user_by_req(request) auth_user = requester.user - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin and target_user != auth_user: - raise AuthError(403, "You are not a server admin") + if target_user != auth_user: + yield assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only whois a local user") @@ -251,20 +262,16 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class PurgeMediaCacheRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/purge_media_cache") +class PurgeMediaCacheRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/purge_media_cache") def __init__(self, hs): self.media_repository = hs.get_media_repository() - super(PurgeMediaCacheRestServlet, self).__init__(hs) + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) logger.info("before_ts: %r", before_ts) @@ -274,9 +281,9 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class PurgeHistoryRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/admin/purge_history/(?P[^/]*)(/(?P[^/]+))?" +class PurgeHistoryRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns( + "/purge_history/(?P[^/]*)(/(?P[^/]+))?" ) def __init__(self, hs): @@ -285,17 +292,13 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): Args: hs (synapse.server.HomeServer) """ - super(PurgeHistoryRestServlet, self).__init__(hs) self.pagination_handler = hs.get_pagination_handler() self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) @@ -371,9 +374,9 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): })) -class PurgeHistoryStatusRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/admin/purge_history_status/(?P[^/]+)" +class PurgeHistoryStatusRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns( + "/purge_history_status/(?P[^/]+)" ) def __init__(self, hs): @@ -382,16 +385,12 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet): Args: hs (synapse.server.HomeServer) """ - super(PurgeHistoryStatusRestServlet, self).__init__(hs) self.pagination_handler = hs.get_pagination_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, purge_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) purge_status = self.pagination_handler.get_purge_status(purge_id) if purge_status is None: @@ -400,15 +399,16 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet): defer.returnValue((200, purge_status.asdict())) -class DeactivateAccountRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/deactivate/(?P[^/]*)") +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/deactivate/(?P[^/]*)") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__(hs) self._deactivate_account_handler = hs.get_deactivate_account_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, target_user_id): + yield assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -419,11 +419,6 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): ) UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") result = yield self._deactivate_account_handler.deactivate_account( target_user_id, erase, @@ -438,13 +433,13 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): })) -class ShutdownRoomRestServlet(ClientV1RestServlet): +class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. Any local aliases will be repointed to a new room created by `new_room_user_id` and kicked users will be auto joined to the new room. """ - PATTERNS = client_path_patterns("/admin/shutdown_room/(?P[^/]+)") + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P[^/]+)") DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" @@ -452,19 +447,18 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): ) def __init__(self, hs): - super(ShutdownRoomRestServlet, self).__init__(hs) + self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() self._room_creation_handler = hs.get_room_creation_handler() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id): requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_user_is_admin(self.auth, requester.user) content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) @@ -564,22 +558,20 @@ class ShutdownRoomRestServlet(ClientV1RestServlet): })) -class QuarantineMediaInRoom(ClientV1RestServlet): +class QuarantineMediaInRoom(RestServlet): """Quarantines all media in a room so that no one can download it via this server. """ - PATTERNS = client_path_patterns("/admin/quarantine_media/(?P[^/]+)") + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") def __init__(self, hs): - super(QuarantineMediaInRoom, self).__init__(hs) self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, room_id): requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_user_is_admin(self.auth, requester.user) num_quarantined = yield self.store.quarantine_media_ids_in_room( room_id, requester.user.to_string(), @@ -588,13 +580,12 @@ class QuarantineMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"num_quarantined": num_quarantined})) -class ListMediaInRoom(ClientV1RestServlet): +class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ - PATTERNS = client_path_patterns("/admin/room/(?P[^/]+)/media") + PATTERNS = historical_admin_path_patterns("/room/(?P[^/]+)/media") def __init__(self, hs): - super(ListMediaInRoom, self).__init__(hs) self.store = hs.get_datastore() @defer.inlineCallbacks @@ -609,11 +600,11 @@ class ListMediaInRoom(ClientV1RestServlet): defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) -class ResetPasswordRestServlet(ClientV1RestServlet): +class ResetPasswordRestServlet(RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ + http://localhost:8008/_synapse/admin/v1/reset_password/ @user:to_reset_password?access_token=admin_access_token JsonBodyToSend: { @@ -622,11 +613,10 @@ class ResetPasswordRestServlet(ClientV1RestServlet): Returns: 200 OK with empty object if success otherwise an error. """ - PATTERNS = client_path_patterns("/admin/reset_password/(?P[^/]*)") + PATTERNS = historical_admin_path_patterns("/reset_password/(?P[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(ResetPasswordRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self._set_password_handler = hs.get_set_password_handler() @@ -636,12 +626,10 @@ class ResetPasswordRestServlet(ClientV1RestServlet): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. """ - UserID.from_string(target_user_id) requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + yield assert_user_is_admin(self.auth, requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + UserID.from_string(target_user_id) params = parse_json_object_from_request(request) assert_params_in_dict(params, ["new_password"]) @@ -653,20 +641,19 @@ class ResetPasswordRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) -class GetUsersPaginatedRestServlet(ClientV1RestServlet): +class GetUsersPaginatedRestServlet(RestServlet): """Get request to get specific number of users from Synapse. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + http://localhost:8008/_synapse/admin/v1/users_paginate/ @admin:user?access_token=admin_access_token&start=0&limit=10 Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = client_path_patterns("/admin/users_paginate/(?P[^/]*)") + PATTERNS = historical_admin_path_patterns("/users_paginate/(?P[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(GetUsersPaginatedRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self.handlers = hs.get_handlers() @@ -676,16 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Get request to get specific number of users from Synapse. This needs user to have administrator access in Synapse. """ + yield assert_requester_is_admin(self.auth, request) + target_user = UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") - - # To allow all users to get the users list - # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") @@ -706,7 +686,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): """Post request to get specific number of users from Synapse.. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + http://localhost:8008/_synapse/admin/v1/users_paginate/ @admin:user?access_token=admin_access_token JsonBodyToSend: { @@ -716,12 +696,8 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ + yield assert_requester_is_admin(self.auth, request) UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") order = "name" # order by name in user table params = parse_json_object_from_request(request) @@ -736,21 +712,20 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class SearchUsersRestServlet(ClientV1RestServlet): +class SearchUsersRestServlet(RestServlet): """Get request to search user table for specific users according to search term. This needs user to have administrator access in Synapse. Example: - http://localhost:8008/_matrix/client/api/v1/admin/search_users/ + http://localhost:8008/_synapse/admin/v1/search_users/ @admin:user?access_token=admin_access_token&term=alice Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = client_path_patterns("/admin/search_users/(?P[^/]*)") + PATTERNS = historical_admin_path_patterns("/search_users/(?P[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() - super(SearchUsersRestServlet, self).__init__(hs) self.hs = hs self.auth = hs.get_auth() self.handlers = hs.get_handlers() @@ -761,12 +736,9 @@ class SearchUsersRestServlet(ClientV1RestServlet): search term. This needs user to have a administrator access in Synapse. """ - target_user = UserID.from_string(target_user_id) - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + yield assert_requester_is_admin(self.auth, request) - if not is_admin: - raise AuthError(403, "You are not a server admin") + target_user = UserID.from_string(target_user_id) # To allow all users to get the users list # if not is_admin and target_user != auth_user: @@ -784,23 +756,20 @@ class SearchUsersRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) -class DeleteGroupAdminRestServlet(ClientV1RestServlet): +class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ - PATTERNS = client_path_patterns("/admin/delete_group/(?P[^/]*)") + PATTERNS = historical_admin_path_patterns("/delete_group/(?P[^/]*)") def __init__(self, hs): - super(DeleteGroupAdminRestServlet, self).__init__(hs) self.group_server = hs.get_groups_server_handler() self.is_mine_id = hs.is_mine_id + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request, group_id): requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): raise SynapseError(400, "Can only delete local groups") @@ -809,27 +778,21 @@ class DeleteGroupAdminRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) -class AccountValidityRenewServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/admin/account_validity/validity$") +class AccountValidityRenewServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/account_validity/validity$") def __init__(self, hs): """ Args: hs (synapse.server.HomeServer): server """ - super(AccountValidityRenewServlet, self).__init__(hs) - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) - - if not is_admin: - raise AuthError(403, "You are not a server admin") + yield assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) @@ -846,8 +809,33 @@ class AccountValidityRenewServlet(ClientV1RestServlet): } defer.returnValue((200, res)) +######################################################################################## +# +# please don't add more servlets here: this file is already long and unwieldy. Put +# them in separate files within the 'admin' package. +# +######################################################################################## + + +class AdminRestResource(JsonResource): + """The REST resource which gets mounted at /_synapse/admin""" + + def __init__(self, hs): + JsonResource.__init__(self, hs, canonical_json=False) + register_servlets(hs, self) + def register_servlets(hs, http_server): + """ + Register all the admin servlets. + """ + register_servlets_for_client_rest_resource(hs, http_server) + SendServerNoticeServlet(hs).register(http_server) + VersionServlet(hs).register(http_server) + + +def register_servlets_for_client_rest_resource(hs, http_server): + """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) @@ -861,6 +849,7 @@ def register_servlets(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server) UserRegisterServlet(hs).register(http_server) - VersionServlet(hs).register(http_server) DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) + # don't add more things here: new servlets should only be exposed on + # /_synapse/admin so should not go here. Instead register them in AdminRestResource. diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py new file mode 100644 index 000000000..881d67b89 --- /dev/null +++ b/synapse/rest/admin/_base.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.api.errors import AuthError + + +@defer.inlineCallbacks +def assert_requester_is_admin(auth, request): + """Verify that the requester is an admin user + + WARNING: MAKE SURE YOU YIELD ON THE RESULT! + + Args: + auth (synapse.api.auth.Auth): + request (twisted.web.server.Request): incoming request + + Returns: + Deferred + + Raises: + AuthError if the requester is not an admin + """ + requester = yield auth.get_user_by_req(request) + yield assert_user_is_admin(auth, requester.user) + + +@defer.inlineCallbacks +def assert_user_is_admin(auth, user_id): + """Verify that the given user is an admin user + + WARNING: MAKE SURE YOU YIELD ON THE RESULT! + + Args: + auth (synapse.api.auth.Auth): + user_id (UserID): + + Returns: + Deferred + + Raises: + AuthError if the user is not an admin + """ + + is_admin = yield auth.is_server_admin(user_id) + if not is_admin: + raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py new file mode 100644 index 000000000..ae5aca9da --- /dev/null +++ b/synapse/rest/admin/server_notice_servlet.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import UserID + + +class SendServerNoticeServlet(RestServlet): + """Servlet which will send a server notice to a given user + + POST /_synapse/admin/v1/send_server_notice + { + "user_id": "@target_user:server_name", + "content": { + "msgtype": "m.text", + "body": "This is my message" + } + } + + returns: + + { + "event_id": "$1895723857jgskldgujpious" + } + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + self.snm = hs.get_server_notices_manager() + + def register(self, json_resource): + PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths( + "POST", + (re.compile(PATTERN + "$"), ), + self.on_POST, + ) + json_resource.register_paths( + "PUT", + (re.compile(PATTERN + "/(?P[^/]*)$",), ), + self.on_PUT, + ) + + @defer.inlineCallbacks + def on_POST(self, request, txn_id=None): + yield assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ("user_id", "content")) + event_type = body.get("type", EventTypes.Message) + state_key = body.get("state_key") + + if not self.snm.is_enabled(): + raise SynapseError(400, "Server notices are not enabled on this server") + + user_id = body["user_id"] + UserID.from_string(user_id) + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Server notices can only be sent to local users") + + event = yield self.snm.send_notice( + user_id=body["user_id"], + type=event_type, + state_key=state_key, + event_content=body["content"], + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + def on_PUT(self, request, txn_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, txn_id, + ) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py deleted file mode 100644 index c77d7aba6..000000000 --- a/synapse/rest/client/v1/base.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains base REST classes for constructing client v1 servlets. -""" - -import logging -import re - -from synapse.api.urls import CLIENT_PREFIX -from synapse.http.servlet import RestServlet -from synapse.rest.client.transactions import HttpTransactionCache - -logger = logging.getLogger(__name__) - - -def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True): - """Creates a regex compiled client path with the correct client path - prefix. - - Args: - path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. - Returns: - SRE_Pattern - """ - patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] - if include_in_unstable: - unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) - for release in releases: - new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release) - patterns.append(re.compile("^" + new_prefix + path_regex)) - return patterns - - -class ClientV1RestServlet(RestServlet): - """A base Synapse REST Servlet for the client version 1 API. - """ - - # This subclass was presumably created to allow the auth for the v1 - # protocol version to be different, however this behaviour was removed. - # it may no longer be necessary - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - self.hs = hs - self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0220acf64..0035182bb 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -19,11 +19,10 @@ import logging from twisted.internet import defer from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import RoomAlias -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) @@ -33,13 +32,14 @@ def register_servlets(hs, http_server): ClientAppserviceDirectoryListServer(hs).register(http_server) -class ClientDirectoryServer(ClientV1RestServlet): - PATTERNS = client_path_patterns("/directory/room/(?P[^/]*)$") +class ClientDirectoryServer(RestServlet): + PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__(hs) + super(ClientDirectoryServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_alias): @@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet): defer.returnValue((200, {})) -class ClientDirectoryListServer(ClientV1RestServlet): - PATTERNS = client_path_patterns("/directory/list/room/(?P[^/]*)$") +class ClientDirectoryListServer(RestServlet): + PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__(hs) + super(ClientDirectoryListServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet): defer.returnValue((200, {})) -class ClientAppserviceDirectoryListServer(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$" +class ClientAppserviceDirectoryListServer(RestServlet): + PATTERNS = client_patterns( + "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__(hs) + super(ClientAppserviceDirectoryListServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() def on_PUT(self, request, network_id, room_id): content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index cd9b3bdbd..84ca36270 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -19,22 +19,22 @@ import logging from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.events.utils import serialize_event +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class EventStreamRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/events$") +class EventStreamRestServlet(RestServlet): + PATTERNS = client_patterns("/events$", v1=True) DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__(hs) + super(EventStreamRestServlet, self).__init__() self.event_stream_handler = hs.get_event_stream_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -77,13 +77,14 @@ class EventStreamRestServlet(ClientV1RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. -class EventRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/events/(?P[^/]*)$") +class EventRestServlet(RestServlet): + PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__(hs) + super(EventRestServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request, event_id): @@ -92,7 +93,8 @@ class EventRestServlet(ClientV1RestServlet): time_now = self.clock.time_msec() if event: - defer.returnValue((200, serialize_event(event, time_now))) + event = yield self._event_serializer.serialize_event(event, time_now) + defer.returnValue((200, event)) else: defer.returnValue((404, "Event not found.")) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 3ead75cb7..0fe5f2d79 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -15,19 +15,19 @@ from twisted.internet import defer -from synapse.http.servlet import parse_boolean +from synapse.http.servlet import RestServlet, parse_boolean +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_patterns - # TODO: Needs unit testing -class InitialSyncRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/initialSync$") +class InitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__(hs) + super(InitialSyncRestServlet, self).__init__() self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5180e9eaf..3b6072862 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -29,12 +29,11 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart from synapse.util.msisdn import phone_number_to_msisdn -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) @@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier): } -class LoginRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login$") +class LoginRestServlet(RestServlet): + PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" JWT_TYPE = "m.login.jwt" def __init__(self, hs): - super(LoginRestServlet, self).__init__(hs) + super(LoginRestServlet, self).__init__() + self.hs = hs self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm @@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet): class CasRedirectServlet(RestServlet): - PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) def __init__(self, hs): super(CasRedirectServlet, self).__init__() @@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet): b"redirectUrl": args[b"redirectUrl"][0] }).encode('ascii') hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/api/v1/login/cas/ticket") + b"/_matrix/client/r0/login/cas/ticket") service_param = urllib.parse.urlencode({ b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) }).encode('ascii') @@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet): finish_request(request) -class CasTicketServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) +class CasTicketServlet(RestServlet): + PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__(hs) + super(CasTicketServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) + self._http_client = hs.get_simple_http_client() @defer.inlineCallbacks def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) - http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), "service": self.cas_service_url } try: - body = yield http_client.get_raw(uri, args) + body = yield self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 430c69233..b8064f261 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -17,19 +17,18 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError - -from .base import ClientV1RestServlet, client_path_patterns +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -class LogoutRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/logout$") +class LogoutRestServlet(RestServlet): + PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__(hs) - self._auth = hs.get_auth() + super(LogoutRestServlet, self).__init__() + self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -38,32 +37,25 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - try: - requester = yield self.auth.get_user_by_req(request) - except AuthError: - # this implies the access token has already been deleted. - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + requester = yield self.auth.get_user_by_req(request) + + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = self.auth.get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) else: - if requester.device_id is None: - # the acccess token wasn't associated with a device. - # Just delete the access token - access_token = self._auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) - else: - yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) defer.returnValue((200, {})) -class LogoutAllRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/logout/all$") +class LogoutAllRestServlet(RestServlet): + PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__(hs) + super(LogoutAllRestServlet, self).__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 045d5a20a..e263da3cb 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -23,21 +23,22 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class PresenceStatusRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/presence/(?P[^/]*)/status") +class PresenceStatusRestServlet(RestServlet): + PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__(hs) + super(PresenceStatusRestServlet, self).__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index a23edd8fe..e15d9d82a 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,26 +16,33 @@ """ This module contains REST servlets to do with profile: /profile/ """ from twisted.internet import defer -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_patterns - -class ProfileDisplaynameRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P[^/]*)/displayname") +class ProfileDisplaynameRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__(hs) + super(ProfileDisplaynameRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = yield self.auth.get_user_by_req(request) + requester_user = requester.user + user = UserID.from_string(user_id) - displayname = yield self.profile_handler.get_displayname( - user, - ) + yield self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = yield self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -65,20 +72,28 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): return (200, {}) -class ProfileAvatarURLRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P[^/]*)/avatar_url") +class ProfileAvatarURLRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__(hs) + super(ProfileAvatarURLRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = yield self.auth.get_user_by_req(request) + requester_user = requester.user + user = UserID.from_string(user_id) - avatar_url = yield self.profile_handler.get_avatar_url( - user, - ) + yield self.profile_handler.check_profile_query_allowed(user, requester_user) + + avatar_url = yield self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -107,23 +122,29 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): return (200, {}) -class ProfileRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P[^/]*)") +class ProfileRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__(hs) + super(ProfileRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = yield self.auth.get_user_by_req(request) + requester_user = requester.user + user = UserID.from_string(user_id) - displayname = yield self.profile_handler.get_displayname( - user, - ) - avatar_url = yield self.profile_handler.get_avatar_url( - user, - ) + yield self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = yield self.profile_handler.get_displayname(user) + avatar_url = yield self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 506ec95dd..3d6326fe2 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,22 +21,22 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import parse_json_value_from_request, parse_string +from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from .base import ClientV1RestServlet, client_path_patterns - -class PushRuleRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/(?Ppushrules/.*)$") +class PushRuleRestServlet(RestServlet): + PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") def __init__(self, hs): - super(PushRuleRestServlet, self).__init__(hs) + super(PushRuleRestServlet, self).__init__() + self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 4c07ae7f4..15d860db3 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -26,17 +26,18 @@ from synapse.http.servlet import ( parse_string, ) from synapse.push import PusherConfigException - -from .base import ClientV1RestServlet, client_path_patterns +from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -class PushersRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/pushers$") +class PushersRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, hs): - super(PushersRestServlet, self).__init__(hs) + super(PushersRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet): return 200, {} -class PushersSetRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/pushers/set$") +class PushersSetRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__(hs) + super(PushersSetRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() @@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ - PATTERNS = client_path_patterns("/pushers/remove$") + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" def __init__(self, hs): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 48da4d557..e8f672c4b 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -26,39 +26,47 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2, serialize_event +from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( + RestServlet, assert_params_in_dict, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class RoomCreateRestServlet(ClientV1RestServlet): +class TransactionRestServlet(RestServlet): + def __init__(self, hs): + super(TransactionRestServlet, self).__init__() + self.txns = HttpTransactionCache(hs) + + +class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): super(RoomCreateRestServlet, self).__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity http_server.register_paths("OPTIONS", - client_path_patterns("/rooms(?:/.*)?$"), + client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS) # define CORS for /createRoom[/txnid] http_server.register_paths("OPTIONS", - client_path_patterns("/createRoom(?:/.*)?$"), + client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS) def on_PUT(self, request, txn_id): @@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(ClientV1RestServlet): +class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() def register(self, http_server): # /room/$roomid/state/$eventtype @@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet): "(?P[^/]*)/(?P[^/]*)$") http_server.register_paths("GET", - client_path_patterns(state_key), + client_patterns(state_key, v1=True), self.on_GET) http_server.register_paths("PUT", - client_path_patterns(state_key), + client_patterns(state_key, v1=True), self.on_PUT) http_server.register_paths("GET", - client_path_patterns(no_state_key), + client_patterns(no_state_key, v1=True), self.on_GET_no_state_key) http_server.register_paths("PUT", - client_path_patterns(no_state_key), + client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key) def on_GET_no_state_key(self, request, room_id, event_type): @@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback -class RoomSendEventRestServlet(ClientV1RestServlet): +class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(ClientV1RestServlet): +class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): super(JoinRoomAliasServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet): # TODO: Needs unit testing -class PublicRoomListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/publicRooms$") +class PublicRoomListRestServlet(TransactionRestServlet): + PATTERNS = client_patterns("/publicRooms$", v1=True) + + def __init__(self, hs): + super(PublicRoomListRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -301,6 +317,12 @@ class PublicRoomListRestServlet(ClientV1RestServlet): try: yield self.auth.get_user_by_req(request, allow_guest=True) except AuthError as e: + # Option to allow servers to require auth when accessing + # /publicRooms via CS API. This is especially helpful in private + # federations. + if self.hs.config.restrict_public_rooms_to_local_users: + raise + # We allow people to not be authed if they're just looking at our # room list, but require auth when we proxy the request. # In both cases we call the auth function, as that has the side @@ -376,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomMemberListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/members$") +class RoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__(hs) + super(RoomMemberListRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -430,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet): # deprecated in favour of /members?membership=join? # except it does custom AS logic and has a simpler return format -class JoinedRoomMemberListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/joined_members$") +class JoinedRoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__(hs) + super(JoinedRoomMemberListRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -451,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): # TODO: Needs better unit testing -class RoomMessageListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/messages$") +class RoomMessageListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__(hs) + super(RoomMessageListRestServlet, self).__init__() self.pagination_handler = hs.get_pagination_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -469,6 +494,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet): if filter_bytes: filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) event_filter = Filter(json.loads(filter_json)) + if event_filter.filter_json.get("event_format", "client") == "federation": + as_client_event = False else: event_filter = None msgs = yield self.pagination_handler.get_messages( @@ -483,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomStateRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/state$") +class RoomStateRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__(hs) + super(RoomStateRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -503,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomInitialSyncRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/initialSync$") +class RoomInitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__(hs) + super(RoomInitialSyncRestServlet, self).__init__() self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -522,15 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): defer.returnValue((200, content)) -class RoomEventServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P[^/]*)/event/(?P[^/]*)$" +class RoomEventServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/event/(?P[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomEventServlet, self).__init__(hs) + super(RoomEventServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -539,20 +570,23 @@ class RoomEventServlet(ClientV1RestServlet): time_now = self.clock.time_msec() if event: - defer.returnValue((200, serialize_event(event, time_now))) + event = yield self._event_serializer.serialize_event(event, time_now) + defer.returnValue((200, event)) else: defer.returnValue((404, "Event not found.")) -class RoomEventContextServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P[^/]*)/context/(?P[^/]*)$" +class RoomEventContextServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/context/(?P[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__(hs) + super(RoomEventContextServlet, self).__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -582,24 +616,27 @@ class RoomEventContextServlet(ClientV1RestServlet): ) time_now = self.clock.time_msec() - results["events_before"] = [ - serialize_event(event, time_now) for event in results["events_before"] - ] - results["event"] = serialize_event(results["event"], time_now) - results["events_after"] = [ - serialize_event(event, time_now) for event in results["events_after"] - ] - results["state"] = [ - serialize_event(event, time_now) for event in results["state"] - ] + results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"], time_now, + ) + results["event"] = yield self._event_serializer.serialize_event( + results["event"], time_now, + ) + results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"], time_now, + ) + results["state"] = yield self._event_serializer.serialize_events( + results["state"], time_now, + ) defer.returnValue((200, results)) -class RoomForgetRestServlet(ClientV1RestServlet): +class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomForgetRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/forget") @@ -626,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomMembershipRestServlet(ClientV1RestServlet): +class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): # /rooms/$roomid/[invite|join|leave] @@ -709,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): ) -class RoomRedactEventRestServlet(ClientV1RestServlet): +class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomRedactEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") @@ -744,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): ) -class RoomTypingRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P[^/]*)/typing/(?P[^/]*)$" +class RoomTypingRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__(hs) + super(RoomTypingRestServlet, self).__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): @@ -785,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) -class SearchRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/search$" - ) +class SearchRestServlet(RestServlet): + PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__(hs) + super(SearchRestServlet, self).__init__() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request): @@ -810,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) -class JoinedRoomsRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/joined_rooms$") +class JoinedRoomsRestServlet(RestServlet): + PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__(hs) + super(JoinedRoomsRestServlet, self).__init__() self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -840,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): """ http_server.register_paths( "POST", - client_path_patterns(regex_string + "$"), + client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", - client_path_patterns(regex_string + "/(?P[^/]*)$"), + client_patterns(regex_string + "/(?P[^/]*)$", v1=True), servlet.on_PUT ) if with_get: http_server.register_paths( "GET", - client_path_patterns(regex_string + "/(?P[^/]*)$"), + client_patterns(regex_string + "/(?P[^/]*)$", v1=True), servlet.on_GET ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 53da905ee..638104921 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -19,11 +19,17 @@ import hmac from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns -class VoipRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/voip/turnServer$") +class VoipRestServlet(RestServlet): + PATTERNS = client_patterns("/voip/turnServer$", v1=True) + + def __init__(self, hs): + super(VoipRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 77434937f..5236d5d56 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -21,14 +21,12 @@ import re from twisted.internet import defer from synapse.api.errors import InteractiveAuthIncompleteError -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.api.urls import CLIENT_API_PREFIX logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,), - v2_alpha=True, - unstable=True): +def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): """Creates a regex compiled client path with the correct client path prefix. @@ -39,13 +37,14 @@ def client_v2_patterns(path_regex, releases=(0,), SRE_Pattern """ patterns = [] - if v2_alpha: - patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) if unstable: - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + unstable_prefix = CLIENT_API_PREFIX + "/unstable" patterns.append(re.compile("^" + unstable_prefix + path_regex)) + if v1: + v1_prefix = CLIENT_API_PREFIX + "/api/v1" + patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: - new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) + new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ee069179f..ca35dc3c8 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,13 +30,13 @@ from synapse.http.servlet import ( from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): super(EmailPasswordRequestTokenRestServlet, self).__init__() @@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): class MsisdnPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") + PATTERNS = client_patterns("/account/password/msisdn/requestToken$") def __init__(self, hs): super(MsisdnPasswordRequestTokenRestServlet, self).__init__() @@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password$") + PATTERNS = client_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/deactivate$") + PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): super(DeactivateAccountRestServlet, self).__init__() @@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): self.hs = hs @@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") + PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") def __init__(self, hs): self.hs = hs @@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid$") + PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() @@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/delete$") + PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() @@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet): class WhoamiRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/whoami$") + PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): super(WhoamiRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f171b8d62..574a6298c 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P[^/]*)/account_data/(?P[^/]*)" ) @@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P[^/]*)" "/rooms/(?P[^/]*)" "/account_data/(?P[^/]*)" diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index fc8dbeb61..55c4ed566 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): - PATTERNS = client_v2_patterns("/account_validity/renew$") + PATTERNS = client_patterns("/account_validity/renew$") SUCCESS_HTML = b"Your account has been successfully renewed." def __init__(self, hs): @@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet): - PATTERNS = client_v2_patterns("/account_validity/send_mail$") + PATTERNS = client_patterns("/account_validity/send_mail$") def __init__(self, hs): """ diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index ac035c773..8dfe5cba0 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -19,11 +19,11 @@ from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import SynapseError -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +from synapse.api.urls import CLIENT_API_PREFIX from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERNS = client_v2_patterns(r"/auth/(?P[\w\.]*)/fallback/web") + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() @@ -139,8 +139,8 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = RECAPTCHA_TEMPLATE % { 'session': session, - 'myurl': "%s/auth/%s/fallback/web" % ( - CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + 'myurl': "%s/r0/auth/%s/fallback/web" % ( + CLIENT_API_PREFIX, LoginType.RECAPTCHA ), 'sitekey': self.hs.config.recaptcha_public_key, } @@ -159,8 +159,8 @@ class AuthRestServlet(RestServlet): self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/auth/%s/fallback/web" % ( - CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS + 'myurl': "%s/r0/auth/%s/fallback/web" % ( + CLIENT_API_PREFIX, LoginType.TERMS ), } html_bytes = html.encode("utf8") @@ -203,8 +203,8 @@ class AuthRestServlet(RestServlet): else: html = RECAPTCHA_TEMPLATE % { 'session': session, - 'myurl': "%s/auth/%s/fallback/web" % ( - CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA + 'myurl': "%s/r0/auth/%s/fallback/web" % ( + CLIENT_API_PREFIX, LoginType.RECAPTCHA ), 'sitekey': self.hs.config.recaptcha_public_key, } @@ -240,8 +240,8 @@ class AuthRestServlet(RestServlet): self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/auth/%s/fallback/web" % ( - CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS + 'myurl': "%s/r0/auth/%s/fallback/web" % ( + CLIENT_API_PREFIX, LoginType.TERMS ), } html_bytes = html.encode("utf8") diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index a868d0609..fc7e2f4dd 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -16,10 +16,10 @@ import logging from twisted.internet import defer -from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) class CapabilitiesRestServlet(RestServlet): """End point to expose the capabilities of the server.""" - PATTERNS = client_v2_patterns("/capabilities$") + PATTERNS = client_patterns("/capabilities$") def __init__(self, hs): """ @@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet): """ super(CapabilitiesRestServlet, self).__init__() self.hs = hs + self.config = hs.config self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet): response = { "capabilities": { "m.room_versions": { - "default": DEFAULT_ROOM_VERSION.identifier, + "default": self.config.default_room_version.identifier, "available": { v.identifier: v.disposition for v in KNOWN_ROOM_VERSIONS.values() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 9b75bb137..78665304a 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -24,13 +24,13 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class DevicesRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) + PATTERNS = client_patterns("/devices$") def __init__(self, hs): """ @@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ - PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False) + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet): class DeviceRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", v2_alpha=False) + PATTERNS = client_patterns("/devices/(?P[^/]*)$") def __init__(self, hs): """ diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index ae8672887..65db48c3c 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID -from ._base import client_v2_patterns, set_timeline_upper_limit +from ._base import client_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") + PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") def __init__(self, hs): super(GetFilterRestServlet, self).__init__() @@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user/(?P[^/]*)/filter") + PATTERNS = client_patterns("/user/(?P[^/]*)/filter") def __init__(self, hs): super(CreateFilterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 21e02c07c..d082385ec 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class GroupServlet(RestServlet): """Get the group profile """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/profile$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): super(GroupServlet, self).__init__() @@ -65,7 +65,7 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): """Get the full group summary """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/summary$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): super(GroupSummaryServlet, self).__init__() @@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" "/rooms/(?P[^/]*)$" @@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/(?P[^/]+)$" ) @@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/$" ) @@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/roles/(?P[^/]+)$" ) @@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/roles/$" ) @@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" "/users/(?P[^/]*)$" @@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/rooms$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): super(GroupRoomServlet, self).__init__() @@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): """Get all users in a group """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/users$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): super(GroupUsersServlet, self).__init__() @@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/invited_users$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): super(GroupInvitedUsersServlet, self).__init__() @@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ - PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): super(GroupSettingJoinPolicyServlet, self).__init__() @@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): """Create a group """ - PATTERNS = client_v2_patterns("/create_group$") + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): super(GroupCreateServlet, self).__init__() @@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" ) @@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" "/config/(?P[^/]*)$" ) @@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" ) @@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" ) @@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/self/leave$" ) @@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/self/join$" ) @@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/self/accept_invite$" ) @@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P[^/]*)/self/update_publicity$" ) @@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/publicised_groups/(?P[^/]*)$" ) @@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/publicised_groups$" ) @@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/joined_groups$" ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 8486086b5..4cbfbf563 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -26,7 +26,7 @@ from synapse.http.servlet import ( ) from synapse.types import StreamToken -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P[^/]+))?$") + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") def __init__(self, hs): """ @@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns("/keys/query$") + PATTERNS = client_patterns("/keys/query$") def __init__(self, hs): """ @@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns("/keys/changes$") + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): """ @@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns("/keys/claim$") + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 2a6ea3df5..53e666989 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -17,25 +17,23 @@ import logging from twisted.internet import defer -from synapse.events.utils import ( - format_event_for_client_v2_without_room_id, - serialize_event, -) +from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$") + PATTERNS = client_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request): @@ -69,11 +67,11 @@ class NotificationsServlet(RestServlet): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": serialize_event( + "event": (yield self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, - ), + )), } if pa["room_id"] not in receipts_by_room: diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index 01c90aa2a..bb927d9f9 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -22,7 +22,7 @@ from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P[^/]*)/openid/request_token" ) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index a6e582a5a..f4bd0d077 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -19,13 +19,13 @@ from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReadMarkerRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/rooms/(?P[^/]*)/read_markers$") + PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") def __init__(self, hs): super(ReadMarkerRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index de370cac4..fa12ac3e4 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -20,13 +20,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/rooms/(?P[^/]*)" "/receipt/(?P[^/]*)" "/(?P[^/]*)$" diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index dc3e265bc..79c085408 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -31,6 +31,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, @@ -42,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler # We ought to be using hmac.compare_digest() but on older pythons it doesn't # exist. It's a _really minor_ security flaw to use plain string comparison @@ -59,7 +60,7 @@ logger = logging.getLogger(__name__) class EmailRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/email/requestToken$") + PATTERNS = client_patterns("/register/email/requestToken$") def __init__(self, hs): """ @@ -97,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") + PATTERNS = client_patterns("/register/msisdn/requestToken$") def __init__(self, hs): """ @@ -141,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/available") + PATTERNS = client_patterns("/register/available") def __init__(self, hs): """ @@ -153,16 +154,18 @@ class UsernameAvailabilityRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( hs.get_clock(), - # Time window of 2s - window_size=2000, - # Artificially delay requests if rate > sleep_limit/window_size - sleep_limit=1, - # Amount of artificial delay to apply - sleep_msec=1000, - # Error with 429 if more than reject_limit requests are queued - reject_limit=1, - # Allow 1 request at a time - concurrent_requests=1, + FederationRateLimitConfig( + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate > sleep_limit/window_size + sleep_limit=1, + # Amount of artificial delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ) ) @defer.inlineCallbacks @@ -179,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet): class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register$") + PATTERNS = client_patterns("/register$") def __init__(self, hs): """ @@ -345,18 +348,22 @@ class RegisterRestServlet(RestServlet): if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: - flows.extend([[LoginType.RECAPTCHA]]) + # Also add a dummy flow here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) # only support the email-only flow if we don't require MSISDN 3PIDs if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) if show_msisdn: # only support the MSISDN-only flow if we don't require email 3PIDs if not require_email: - flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], + [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) else: # only support 3PIDless registration if no 3PIDs are required @@ -379,7 +386,15 @@ class RegisterRestServlet(RestServlet): if self.hs.config.user_consent_at_registration: new_flows = [] for flow in flows: - flow.append(LoginType.TERMS) + inserted = False + # m.login.terms should go near the end but before msisdn or email auth + for i, stage in enumerate(flow): + if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: + flow.insert(i, LoginType.TERMS) + inserted = True + break + if not inserted: + flow.append(LoginType.TERMS) flows.extend(new_flows) auth_result, params, session_id = yield self.auth_handler.check_auth( @@ -391,13 +406,6 @@ class RegisterRestServlet(RestServlet): # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. - # - # Also check that we're not trying to register a 3pid that's already - # been registered. - # - # This has probably happened in /register/email/requestToken as well, - # but if a user hits this endpoint twice then clicks on each link from - # the two activation emails, they would register the same 3pid twice. if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: @@ -413,17 +421,6 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existingUid = yield self.store.get_user_id_by_threepid( - medium, address, - ) - - if existingUid is not None: - raise SynapseError( - 400, - "%s is already in use" % medium, - Codes.THREEPID_IN_USE, - ) - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", @@ -446,6 +443,28 @@ class RegisterRestServlet(RestServlet): if auth_result: threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + # Also check that we're not trying to register a 3pid that's already + # been registered. + # + # This has probably happened in /register/email/requestToken as well, + # but if a user hits this endpoint twice then clicks on each link from + # the two activation emails, they would register the same 3pid twice. + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]['medium'] + address = auth_result[login_type]['address'] + + existingUid = yield self.store.get_user_id_by_threepid( + medium, address, + ) + + if existingUid is not None: + raise SynapseError( + 400, + "%s is already in use" % medium, + Codes.THREEPID_IN_USE, + ) + (registered_user_id, _) = yield self.registration_handler.register( localpart=desired_username, password=new_password, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py new file mode 100644 index 000000000..f8f8742bd --- /dev/null +++ b/synapse/rest/client/v2_alpha/relations.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P[^/]*)/send_relation" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" + ) + + def __init__(self, hs): + super(RelationSendServlet, self).__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + self.txns = HttpTransactionCache(hs) + + def register(self, http_server): + http_server.register_paths( + "POST", + client_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + ) + http_server.register_paths( + "PUT", + client_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), + self.on_PUT, + ) + + def on_PUT(self, request, *args, **kwargs): + return self.txns.fetch_or_execute_request( + request, self.on_PUT_or_POST, request, *args, **kwargs + ) + + @defer.inlineCallbacks + def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + event = yield self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + +class RelationPaginationServlet(RestServlet): + """API to paginate relations on an event by topological ordering, optionally + filtered by relation type and event type. + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/relations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +class RelationAggregationPaginationServlet(RestServlet): + """API to paginate aggregation groups of relations, e.g. paginate the + types and counts of the reactions on the events. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id} + + { + chunk: [ + { + "type": "m.reaction", + "key": "👍", + "count": 3 + } + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type not in (RelationTypes.ANNOTATION, None): + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + res = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, res.to_dict())) + + +class RelationAggregationGroupPaginationServlet(RestServlet): + """API to paginate within an aggregation group of relations, e.g. paginate + all the 👍 reactions on an event. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 + + { + chunk: [ + { + "type": "m.reaction", + "content": { + "m.relates_to": { + "rel_type": "m.annotation", + "key": "👍" + } + } + }, + ... + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationGroupPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 95d2a71ec..10198662a 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -27,13 +27,13 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/rooms/(?P[^/]*)/report/(?P[^/]*)$" ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 220a0de30..87779645f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -24,13 +24,13 @@ from synapse.http.servlet import ( parse_string, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class RoomKeysServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" ) @@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/version$" ) @@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/version(/(?P[^/]+))?$" ) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 3db7ff8d1..c621a90fb 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -25,7 +25,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -47,10 +47,9 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( # /rooms/$roomid/upgrade "/rooms/(?P[^/]*)/upgrade$", - v2_alpha=False, ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index a9e9a47a0..120a71336 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -21,15 +21,14 @@ from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request from synapse.rest.client.transactions import HttpTransactionCache -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/sendToDevice/(?P[^/]*)/(?P[^/]*)$", - v2_alpha=False ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 39d157a44..148fc6c98 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -26,14 +26,13 @@ from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, - serialize_event, ) from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken -from ._base import client_v2_patterns, set_timeline_upper_limit +from ._base import client_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ class SyncRestServlet(RestServlet): } """ - PATTERNS = client_v2_patterns("/sync$") + PATTERNS = client_patterns("/sync$") ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) def __init__(self, hs): @@ -86,6 +85,7 @@ class SyncRestServlet(RestServlet): self.filtering = hs.get_filtering() self.presence_handler = hs.get_presence_handler() self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request): @@ -168,14 +168,14 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = self.encode_response( + response_content = yield self.encode_response( time_now, sync_result, requester.access_token_id, filter ) defer.returnValue((200, response_content)) - @staticmethod - def encode_response(time_now, sync_result, access_token_id, filter): + @defer.inlineCallbacks + def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == 'client': event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == 'federation': @@ -183,24 +183,24 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format, )) - joined = SyncRestServlet.encode_joined( + joined = yield self.encode_joined( sync_result.joined, time_now, access_token_id, filter.event_fields, event_formatter, ) - invited = SyncRestServlet.encode_invited( + invited = yield self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter, ) - archived = SyncRestServlet.encode_archived( + archived = yield self.encode_archived( sync_result.archived, time_now, access_token_id, filter.event_fields, event_formatter, ) - return { + defer.returnValue({ "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, "device_lists": { @@ -222,7 +222,7 @@ class SyncRestServlet(RestServlet): }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), - } + }) @staticmethod def encode_presence(events, time_now): @@ -239,8 +239,8 @@ class SyncRestServlet(RestServlet): ] } - @staticmethod - def encode_joined(rooms, time_now, token_id, event_fields, event_formatter): + @defer.inlineCallbacks + def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): """ Encode the joined rooms in a sync result @@ -261,15 +261,15 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = SyncRestServlet.encode_room( + joined[room.room_id] = yield self.encode_room( room, time_now, token_id, joined=True, only_fields=event_fields, event_formatter=event_formatter, ) - return joined + defer.returnValue(joined) - @staticmethod - def encode_invited(rooms, time_now, token_id, event_formatter): + @defer.inlineCallbacks + def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -289,7 +289,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = serialize_event( + invite = yield self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, event_format=event_formatter, is_invite=True, @@ -302,10 +302,10 @@ class SyncRestServlet(RestServlet): "invite_state": {"events": invited_state} } - return invited + defer.returnValue(invited) - @staticmethod - def encode_archived(rooms, time_now, token_id, event_fields, event_formatter): + @defer.inlineCallbacks + def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): """ Encode the archived rooms in a sync result @@ -326,17 +326,17 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = SyncRestServlet.encode_room( + joined[room.room_id] = yield self.encode_room( room, time_now, token_id, joined=False, only_fields=event_fields, event_formatter=event_formatter, ) - return joined + defer.returnValue(joined) - @staticmethod + @defer.inlineCallbacks def encode_room( - room, time_now, token_id, joined, + self, room, time_now, token_id, joined, only_fields, event_formatter, ): """ @@ -355,9 +355,13 @@ class SyncRestServlet(RestServlet): Returns: dict[str, object]: the room, encoded in our response format """ - def serialize(event): - return serialize_event( - event, time_now, token_id=token_id, + def serialize(events): + return self._event_serializer.serialize_events( + events, time_now=time_now, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_aggregations=False, + token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, ) @@ -376,8 +380,8 @@ class SyncRestServlet(RestServlet): event.event_id, room.room_id, event.room_id, ) - serialized_state = [serialize(e) for e in state_events] - serialized_timeline = [serialize(e) for e in timeline_events] + serialized_state = yield serialize(state_events) + serialized_timeline = yield serialize(timeline_events) account_data = room.account_data @@ -397,7 +401,7 @@ class SyncRestServlet(RestServlet): result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - return result + defer.returnValue(result) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 4fea614e9..ebff7cff4 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags" ) @@ -54,7 +54,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index b9b5d0767..e7a987466 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -21,13 +21,13 @@ from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols") + PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P[^/]+)$") + PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P[^/]+))?$") + PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P[^/]+))?$") + PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6e76b9e9c..6c366142e 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns class TokenRefreshRestServlet(RestServlet): @@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ - PATTERNS = client_v2_patterns("/tokenrefresh") + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 36b02de37..69e4efc47 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -20,13 +20,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class UserDirectorySearchRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user_directory/search$") + PATTERNS = client_patterns("/user_directory/search$") def __init__(self, hs): """ diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index eb8782aa6..8a730bbc3 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -20,7 +20,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.servlet import parse_integer, parse_json_object_from_request @@ -89,7 +89,7 @@ class RemoteKey(Resource): isLeaf = True def __init__(self, hs): - self.keyring = hs.get_keyring() + self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -215,15 +215,7 @@ class RemoteKey(Resource): json_results.add(bytes(result["key_json"])) if cache_misses and query_remote_on_cache_miss: - for server_name, key_ids in cache_misses.items(): - try: - yield self.keyring.get_server_verify_key_v2_direct( - server_name, key_ids - ) - except KeyLookupError as e: - logger.info("Failed to fetch key: %s", e) - except Exception: - logger.exception("Failed to get key for %r", server_name) + yield self.fetcher.get_keys(cache_misses) yield self.query_keys( request, query, query_remote_on_cache_miss=False ) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index e2b5df701..2dcc8f74d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -191,6 +191,10 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam # in that case. logger.warning("Failed to write to consumer: %s %s", type(e), e) + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + finish_request(request) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bdffa9780..856967735 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -444,6 +444,9 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + if t_method == "crop": t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": @@ -578,6 +581,12 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = yield logcontext.defer_to_thread( + self.hs.get_reactor(), + thumbnailer.transpose + ) + # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails = {} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ba3ab1d37..acf87709f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -31,6 +31,7 @@ from six.moves import urllib_parse as urlparse from canonicaljson import json from twisted.internet import defer +from twisted.internet.error import DNSLookupError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -328,9 +329,18 @@ class PreviewUrlResource(Resource): # handler will return a SynapseError to the client instead of # blank data or a 500. raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, "DNS resolution failure during URL preview generation", + Codes.UNKNOWN + ) except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) + raise SynapseError( 500, "Failed to download content: %s" % ( traceback.format_exception_only(sys.exc_info()[0], e), diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 5aa03031f..d90cbfb56 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider): """ def __init__(self, hs, config): + self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 5305e9175..35a750923 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -56,8 +56,8 @@ class ThumbnailResource(Resource): def _async_render_GET(self, request): set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) - width = parse_integer(request, "width") - height = parse_integer(request, "height") + width = parse_integer(request, "width", required=True) + height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") m_type = parse_string(request, "type", "image/png") diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a4b26c258..3efd0d80f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -20,6 +20,17 @@ import PIL.Image as Image logger = logging.getLogger(__name__) +EXIF_ORIENTATION_TAG = 0x0112 +EXIF_TRANSPOSE_MAPPINGS = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90 +} + class Thumbnailer(object): @@ -31,6 +42,30 @@ class Thumbnailer(object): def __init__(self, input_path): self.image = Image.open(input_path) self.width, self.height = self.image.size + self.transpose_method = None + try: + # We don't use ImageOps.exif_transpose since it crashes with big EXIF + image_exif = self.image._getexif() + if image_exif is not None: + image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) + self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) + except Exception as e: + # A lot of parsing errors can happen when parsing EXIF + logger.info("Error parsing image EXIF information: %s", e) + + def transpose(self): + """Transpose the image using its EXIF Orientation tag + + Returns: + Tuple[int, int]: (width, height) containing the new image size in pixels. + """ + if self.transpose_method is not None: + self.image = self.image.transpose(self.transpose_method) + self.width, self.height = self.image.size + self.transpose_method = None + # We don't need EXIF any more + self.image.info["exif"] = None + return self.image.size def aspect(self, max_width, max_height): """Calculate the largest size that preserves aspect ratio which diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index ab901e63f..a7fa4f39a 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -68,6 +68,6 @@ class WellKnownResource(Resource): request.setHeader(b"Content-Type", b"text/plain") return b'.well-known not available' - logger.error("returning: %s", r) + logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") return json.dumps(r).encode("utf-8") diff --git a/synapse/server.py b/synapse/server.py index 8c30ac2fa..9229a68a8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -35,6 +35,7 @@ from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker +from synapse.events.utils import EventClientSerializer from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, @@ -71,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.user_directory import UserDirectoryHandler @@ -138,6 +140,7 @@ class HomeServer(object): 'acme_handler', 'auth_handler', 'device_handler', + 'stats_handler', 'e2e_keys_handler', 'e2e_room_keys_handler', 'event_handler', @@ -185,10 +188,12 @@ class HomeServer(object): 'sendmail', 'registration_handler', 'account_validity_handler', + 'event_client_serializer', ] REQUIRED_ON_MASTER_STARTUP = [ "user_directory_handler", + "stats_handler" ] # This is overridden in derived application classes @@ -472,6 +477,9 @@ class HomeServer(object): def build_secrets(self): return Secrets() + def build_stats_handler(self): + return StatsHandler(self) + def build_spam_checker(self): return SpamChecker(self) @@ -511,6 +519,9 @@ class HomeServer(object): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_event_client_serializer(self): + return EventClientSerializer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3ba3a967c..9583e82d5 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -18,7 +18,6 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage - class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c432041b4..71316f7d0 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,6 +36,7 @@ from .engines import PostgresEngine from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events import EventsStore +from .events_bg_updates import EventsBackgroundUpdatesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -49,11 +50,13 @@ from .pusher import PusherStore from .receipts import ReceiptsStore from .registration import RegistrationStore from .rejections import RejectionsStore +from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore from .signatures import SignatureStore from .state import StateStore +from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore @@ -64,6 +67,7 @@ logger = logging.getLogger(__name__) class DataStore( + EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RegistrationStore, @@ -99,6 +103,8 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + StatsStore, + RelationsStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 983ce026e..52891bb9e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +16,7 @@ # limitations under the License. import itertools import logging +import random import sys import threading import time @@ -227,6 +230,8 @@ class SQLBaseStore(object): # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + self._account_validity = self.hs.config.account_validity + # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -243,6 +248,16 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + self.rand = random.SystemRandom() + + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -275,6 +290,67 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL;" + ) + txn.execute(sql, []) + + res = self.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, + user["name"], + use_delta=True, + ) + + yield self.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + }, + ) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -1203,7 +1279,8 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues), ) - return txn.execute(sql, list(keyvalues.values())) + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount def _simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( @@ -1222,9 +1299,12 @@ class SQLBaseStore(object): column : column name to test for inclusion against `iterable` iterable : list keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted """ if not iterable: - return + return 0 sql = "DELETE FROM %s" % table @@ -1239,7 +1319,9 @@ class SQLBaseStore(object): if clauses: sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - return txn.execute(sql, values) + txn.execute(sql, values) + + return txn.rowcount def _get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 6092f600b..eb329ebd8 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = json.loads(entry["event_ids"]) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue( AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index fed4ea361..9b0a99cb4 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -118,7 +118,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): defer.returnValue(count) def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit=100 + self, destination, last_stream_id, current_stream_id, limit ): """ Args: diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 956f87657..09e39c2c2 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ return self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self._get_events) + ).addCallback(self.get_events_as_list) def get_auth_chain_ids(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list, limit, ) - .addCallback(self._get_events) + .addCallback(self.get_events_as_list) .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) ) @@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self._get_events(ids) + events = yield self.get_events_as_list(ids) defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7a7f841c6..f9162be9b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -219,41 +220,11 @@ class EventsStore( EventsWorkerStore, BackgroundUpdateStore, ): - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts - ) - self.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, - self._background_reindex_fields_sender, - ) - - self.register_background_index_update( - "event_contains_url_index", - index_name="event_contains_url_index", - table="events", - columns=["room_id", "topological_ordering", "stream_ordering"], - where_clause="contains_url = true AND outlier = false", - ) - - # an event_id index on event_search is useful for the purge_history - # api. Plus it means we get to enforce some integrity with a UNIQUE - # clause - self.register_background_index_update( - "event_search_event_id_idx", - index_name="event_search_event_id_idx", - table="event_search", - columns=["event_id"], - unique=True, - psql_only=True, - ) self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks @@ -554,10 +525,18 @@ class EventsStore( e_id for event in new_events for e_id in event.prev_event_ids() ) - # Finally, remove any events which are prev_events of any existing events. + # Remove any events which are prev_events of any existing events. existing_prevs = yield self._get_events_which_are_prevs(result) result.difference_update(existing_prevs) + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + defer.returnValue(result) @defer.inlineCallbacks @@ -573,12 +552,13 @@ class EventsStore( """ results = [] - def _get_events(txn, batch): + def _get_events_which_are_prevs_txn(txn, batch): sql = """ - SELECT prev_event_id + SELECT prev_event_id, internal_metadata FROM event_edges INNER JOIN events USING (event_id) LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) WHERE prev_event_id IN (%s) AND NOT events.outlier @@ -588,13 +568,85 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend(r[0] for r in txn) + results.extend( + r[0] + for r in txn + if not json.loads(r[1]).get("soft_failed") + ) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events_which_are_prevs_txn, + chunk, + ) defer.returnValue(results) + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. This includes all soft-failed events + # and their prev events. + existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_recursively_check), + ) + + txn.execute(sql, to_recursively_check) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_prevs_before_rejected", + _get_prevs_before_rejected_txn, + chunk, + ) + + defer.returnValue(existing_prevs) + @defer.inlineCallbacks def _get_new_state_after_events( self, room_id, events_context, old_latest_event_ids, new_latest_event_ids @@ -1325,6 +1377,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1351,6 +1406,8 @@ class EventsStore( # Insert into the event_search table. self._store_guest_access_txn(txn, event) + self._handle_event_relations(txn, event) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1493,153 +1550,6 @@ class EventsStore( ret = yield self.runInteraction("count_daily_active_rooms", _count) defer.returnValue(ret) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, json FROM events" - " INNER JOIN event_json USING (event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - - update_rows = [] - for row in rows: - try: - event_id = row[1] - event_json = json.loads(row[2]) - sender = event_json["sender"] - content = event_json["content"] - - contains_url = "url" in content - if contains_url: - contains_url &= isinstance(content["url"], text_type) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - update_rows.append((sender, contains_url, event_id)) - - sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - } - - self._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) - - defer.returnValue(result) - - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id FROM events" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - rows_to_update = [] - - chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] - for chunk in chunks: - ev_rows = self._simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, - ) - - for row in ev_rows: - event_id = row["event_id"] - event_json = json.loads(row["json"]) - try: - origin_server_ts = event_json["origin_server_ts"] - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - rows_to_update.append((origin_server_ts, event_id)) - - sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update), - } - - self._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress - ) - - return len(rows_to_update) - - result = yield self.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) - - defer.returnValue(result) - def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() @@ -1655,10 +1565,11 @@ class EventsStore( def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1673,11 +1584,12 @@ class EventsStore( sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering DESC" @@ -1698,10 +1610,11 @@ class EventsStore( def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1716,11 +1629,12 @@ class EventsStore( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" " ORDER BY event_stream_ordering DESC" diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py new file mode 100644 index 000000000..75c1935bf --- /dev/null +++ b/synapse/storage/events_bg_updates.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import text_type + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage.background_updates import BackgroundUpdateStore + +logger = logging.getLogger(__name__) + + +class EventsBackgroundUpdatesStore(BackgroundUpdateStore): + + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + + def __init__(self, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + self.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + + self.register_background_update_handler( + self.DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, + ) + + @defer.inlineCallbacks + def _background_reindex_fields_sender(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + + update_rows = [] + for row in rows: + try: + event_id = row[1] + event_json = json.loads(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], text_type) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + update_rows.append((sender, contains_url, event_id)) + + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" + + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + } + + self._background_update_progress_txn( + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update), + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their successor events, if any, and the successor events' + # rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, (batch_size,) + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + """ % ( + ",".join("?" for _ in to_check), + ) + txn.execute(sql, to_check) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph[event_id].add(prev_event_id) + continue + + graph[event_id] = {prev_event_id} + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. + to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + deleted = self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) + + if deleted: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",) + ) + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, + (room_id,) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + ) + + if not num_handled: + yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", + _drop_table_txn, + ) + + defer.returnValue(num_handled) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 663991a9b..178242804 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import division + import itertools import logging from collections import namedtuple @@ -103,7 +105,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - events = yield self._get_events( + events = yield self.get_events_as_list( [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -142,7 +144,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : Dict from event_id to event. """ - events = yield self._get_events( + events = yield self.get_events_as_list( event_ids, check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -152,13 +154,32 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @defer.inlineCallbacks - def _get_events( + def get_events_as_list( self, event_ids, check_redacted=True, get_prev_content=False, allow_rejected=False, ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[list[EventBase]]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + if not event_ids: defer.returnValue([]) @@ -202,21 +223,22 @@ class EventsWorkerStore(SQLBaseStore): # # The problem is that we end up at this point when an event # which has been redacted is pulled out of the database by - # _enqueue_events, because _enqueue_events needs to check the - # redaction before it can cache the redacted event. So obviously, - # calling get_event to get the redacted event out of the database - # gives us an infinite loop. + # _enqueue_events, because _enqueue_events needs to check + # the redaction before it can cache the redacted event. So + # obviously, calling get_event to get the redacted event out + # of the database gives us an infinite loop. # - # For now (quick hack to fix during 0.99 release cycle), we just - # go and fetch the relevant row from the db, but it would be nice - # to think about how we can cache this rather than hit the db - # every time we access a redaction event. + # For now (quick hack to fix during 0.99 release cycle), we + # just go and fetch the relevant row from the db, but it + # would be nice to think about how we can cache this rather + # than hit the db every time we access a redaction event. # # One thought on how to do this: - # 1. split _get_events up so that it is divided into (a) get the - # rawish event from the db/cache, (b) do the redaction/rejection - # filtering - # 2. have _get_event_from_row just call the first half of that + # 1. split get_events_as_list up so that it is divided into + # (a) get the rawish event from the db/cache, (b) do the + # redaction/rejection filtering + # 2. have _get_event_from_row just call the first half of + # that orig_sender = yield self._simple_select_one_onecol( table="events", @@ -590,4 +612,79 @@ class EventsWorkerStore(SQLBaseStore): return res - return self.runInteraction("get_rejection_reasons", f) + return self.runInteraction("get_seen_events_with_rejections", f) + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_total_state_event_counts. + """ + # We join against the events table as that has an index on room_id + sql = """ + SELECT COUNT(*) FROM state_events + INNER JOIN events USING (room_id, event_id) + WHERE room_id=? + """ + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, room_id + ) + + def _get_current_state_event_counts_txn(self, txn, room_id): + """ + See get_current_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_current_state_event_counts(self, room_id): + """ + Gets the current number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_current_state_event_counts", + self._get_current_state_event_counts_txn, room_id + ) + + @defer.inlineCallbacks + def get_room_complexity(self, room_id): + """ + Get a rough approximation of the complexity of the room. This is used by + remote servers to decide whether they wish to join the room or not. + Higher complexity value indicates that being in the room will consume + more resources. + + Args: + room_id (str) + + Returns: + Deferred[dict[str:int]] of complexity version to complexity. + """ + state_events = yield self.get_current_state_event_counts(room_id) + + # Call this one "v1", so we can introduce new ones as we want to develop + # it. + complexity_v1 = round(state_events / 500, 2) + + defer.returnValue({"v1": complexity_v1}) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 703654179..5300720db 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -19,6 +19,7 @@ import logging import six +import attr from signedjson.key import decode_verify_key_bytes from synapse.util import batch_iter @@ -36,6 +37,12 @@ else: db_binary_type = memoryview +@attr.s(slots=True, frozen=True) +class FetchKeyResult(object): + verify_key = attr.ib() # VerifyKey: the key itself + valid_until_ts = attr.ib() # int: how long we can use this key for + + class KeyStore(SQLBaseStore): """Persistence for signature verification keys """ @@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore): iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: - map from (server_name, key_id) -> VerifyKey, or None if the key is + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is unknown """ keys = {} @@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore): # batch_iter always returns tuples so it's safe to do len(batch) sql = ( - "SELECT server_name, key_id, verify_key FROM server_signature_keys " - "WHERE 1=0" + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" ) + " OR (server_name=? AND key_id=?)" * len(batch) txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) for row in txn: - server_name, key_id, key_bytes = row - keys[(server_name, key_id)] = decode_verify_key_bytes( - key_id, bytes(key_bytes) + server_name, key_id, key_bytes, ts_valid_until_ms = row + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, ) + keys[(server_name, key_id)] = res def _txn(txn): for batch in batch_iter(server_name_and_key_ids, 50): @@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore): return self.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_key( - self, server_name, from_server, time_now_ms, verify_key - ): - """Stores a NACL verification key for the given server. + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. Args: - server_name (str): The name of the server. - from_server (str): Where the verification key was looked up - time_now_ms (int): The time now in milliseconds - verify_key (nacl.signing.VerifyKey): The NACL verify key. + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) - - # XXX fix this to not need a lock (#3819) - def _txn(txn): - self._simple_upsert_txn( - txn, - table="server_signature_keys", - keyvalues={"server_name": server_name, "key_id": key_id}, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": db_binary_type(verify_key.encode()), - }, + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) ) # invalidate takes a tuple corresponding to the params of # _get_server_verify_key. _get_server_verify_key only takes one # param, which is itself the 2-tuple (server_name, key_id). - txn.call_after( - self._get_server_verify_key.invalidate, ((server_name, key_id),) - ) + invalidations.append((server_name, key_id)) - return self.runInteraction("store_server_verify_key", _txn) + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i, )) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 03a06a83d..4cf159ba8 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -725,17 +727,7 @@ class RegistrationStore( raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) if self._account_validity.enabled: - now_ms = self.clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - } - ) + self.set_expiration_date_for_user_txn(txn, user_id) if token: # it's possible for this to get a conflict, but only for a single user diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py new file mode 100644 index 000000000..4c83800cc --- /dev/null +++ b/synapse/storage/relations.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from twisted.internet import defer + +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.stream import generate_pagination_where_clause +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +@attr.s +class PaginationChunk(object): + """Returned by relation pagination APIs. + + Attributes: + chunk (list): The rows returned by pagination + next_batch (Any|None): Token to fetch next set of results with, if + None then there are no more results. + prev_batch (Any|None): Token to fetch previous set of results with, if + None then there are no previous results. + """ + + chunk = attr.ib() + next_batch = attr.ib(default=None) + prev_batch = attr.ib(default=None) + + def to_dict(self): + d = {"chunk": self.chunk} + + if self.next_batch: + d["next_batch"] = self.next_batch.to_string() + + if self.prev_batch: + d["prev_batch"] = self.prev_batch.to_string() + + return d + + +@attr.s(frozen=True, slots=True) +class RelationPaginationToken(object): + """Pagination token for relation pagination API. + + As the results are order by topological ordering, we can use the + `topological_ordering` and `stream_ordering` fields of the events at the + boundaries of the chunk as pagination tokens. + + Attributes: + topological (int): The topological ordering of the boundary event + stream (int): The stream ordering of the boundary event. + """ + + topological = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + t, s = string.split("-") + return RelationPaginationToken(int(t), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.topological, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +@attr.s(frozen=True, slots=True) +class AggregationPaginationToken(object): + """Pagination token for relation aggregation pagination API. + + As the results are order by count and then MAX(stream_ordering) of the + aggregation groups, we can just use them as our pagination token. + + Attributes: + count (int): The count of relations in the boundar group. + stream (int): The MAX stream ordering in the boundary group. + """ + + count = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + c, s = string.split("-") + return AggregationPaginationToken(int(c), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.count, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + defer.returnValue(edit_event) + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 57df17bcc..761791332 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,6 +142,27 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction("get_room_summary", _get_room_summary_txn) + def _get_user_counts_in_room_txn(self, txn, room_id): + """ + Get the user count in a room by membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + sql = """ + SELECT m.membership, count(*) FROM room_memberships as m + INNER JOIN current_state_events as c USING(event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + return {row[0]: row[1] for row in txn} + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql similarity index 83% rename from synapse/storage/schema/delta/54/account_validity.sql rename to synapse/storage/schema/delta/54/account_validity_with_renewal.sql index 235762600..0adb2ad55 100644 --- a/synapse/storage/schema/delta/54/account_validity.sql +++ b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql @@ -13,6 +13,9 @@ * limitations under the License. */ +-- We previously changed the schema for this table without renaming the file, which means +-- that some databases might still be using the old schema. This ensures Synapse uses the +-- right schema for the table. DROP TABLE IF EXISTS account_validity; -- Track what users are in public rooms. diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 000000000..c01aa9d2d --- /dev/null +++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. */ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql new file mode 100644 index 000000000..b062ec840 --- /dev/null +++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Start a background job to cleanup extremities that were incorrectly added +-- by bug #5269. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delete_soft_failed_extremities', '{}'); + +DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. +CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; +CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql new file mode 100644 index 000000000..134862b87 --- /dev/null +++ b/synapse/storage/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql new file mode 100644 index 000000000..652e58308 --- /dev/null +++ b/synapse/storage/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/schema/delta/54/stats2.sql new file mode 100644 index 000000000..3b2d48447 --- /dev/null +++ b/synapse/storage/schema/delta/54/stats2.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This delta file gets run after `54/stats.sql` delta. + +-- We want to add some indices to the temporary stats table, so we re-insert +-- 'populate_stats_createtables' if we are still processing the rooms update. +INSERT INTO background_updates (update_name, progress_json) + SELECT 'populate_stats_createtables', '{}' + WHERE + 'populate_stats_process_rooms' IN ( + SELECT update_name FROM background_updates + ) + AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists + SELECT update_name FROM background_updates + ); diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 226f8f1b7..ff49eaae0 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} @@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 56e42f583..5fdb44210 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -22,6 +22,24 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): def get_current_state_deltas(self, prev_stream_id): + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id (int): point to get changes since (exclusive) + + Returns: + Deferred[list[dict]]: results + """ prev_stream_id = int(prev_stream_id) if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id @@ -66,10 +84,16 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self._simple_select_one_onecol_txn( + txn, table="current_state_delta_stream", keyvalues={}, retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py new file mode 100644 index 000000000..ff266b09b --- /dev/null +++ b/synapse/storage/stats.py @@ -0,0 +1,468 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.prepare_database import get_statements +from synapse.storage.state_deltas import StateDeltasStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "state_events", + ), + "user": ("public_rooms", "private_rooms"), +} + +TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +TEMP_TABLE = "_temp_populate_stats" + + +class StatsStore(StateDeltasStore): + def __init__(self, db_conn, hs): + super(StatsStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.register_background_update_handler( + "populate_stats_createtables", self._populate_stats_createtables + ) + self.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.register_background_update_handler( + "populate_stats_cleanup", self._populate_stats_cleanup + ) + + @defer.inlineCallbacks + def _populate_stats_createtables(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + # Get all the rooms that we want to process. + def _make_staging_area(txn): + # Create the temporary tables + stmts = get_statements(""" + -- We just recreate the table, we'll be reinserting the + -- correct entries again later anyway. + DROP TABLE IF EXISTS {temp}_rooms; + + CREATE TABLE IF NOT EXISTS {temp}_rooms( + room_id TEXT NOT NULL, + events BIGINT NOT NULL + ); + + CREATE INDEX {temp}_rooms_events + ON {temp}_rooms(events); + CREATE INDEX {temp}_rooms_id + ON {temp}_rooms(room_id); + """.format(temp=TEMP_TABLE).splitlines()) + + for statement in stmts: + txn.execute(statement) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database, only adding + # those that we haven't (i.e. those not in room_stats_earliest_token) + sql = """ + INSERT INTO %s_rooms (room_id, events) + SELECT c.room_id, count(*) FROM current_state_events AS c + LEFT JOIN room_stats_earliest_token AS t USING (room_id) + WHERE t.room_id IS NULL + GROUP BY c.room_id + """ % (TEMP_TABLE,) + txn.execute(sql) + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.runInteraction("populate_stats_temp_build", _make_staging_area) + yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + self.get_earliest_token_for_room_stats.invalidate_all() + + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + position = yield self._simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_stats_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_process_rooms(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_stats() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s_rooms + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE, + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.runInteraction( + "populate_stats_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + logger.info( + "Processing the next %d rooms of %d remaining", + len(rooms_to_work_on), progress["remaining"], + ) + + # Number of state events we've processed by going through each room + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + + current_state_ids = yield self.get_current_state_ids(room_id) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + history_visibility_id = current_state_ids.get( + (EventTypes.RoomHistoryVisibility, "") + ) + encryption_id = current_state_ids.get((EventTypes.RoomEncryption, "")) + name_id = current_state_ids.get((EventTypes.Name, "")) + topic_id = current_state_ids.get((EventTypes.Topic, "")) + avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) + canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) + + state_events = yield self.get_events([ + join_rules_id, history_visibility_id, encryption_id, name_id, + topic_id, avatar_id, canonical_alias_id, + ]) + + def _get_or_none(event_id, arg): + event = state_events.get(event_id) + if event: + return event.content.get(arg) + return None + + yield self.update_room_state( + room_id, + { + "join_rules": _get_or_none(join_rules_id, "join_rule"), + "history_visibility": _get_or_none( + history_visibility_id, "history_visibility" + ), + "encryption": _get_or_none(encryption_id, "algorithm"), + "name": _get_or_none(name_id, "name"), + "topic": _get_or_none(topic_id, "topic"), + "avatar": _get_or_none(avatar_id, "url"), + "canonical_alias": _get_or_none(canonical_alias_id, "alias"), + }, + ) + + now = self.hs.get_reactor().seconds() + + # quantise time to the nearest bucket + now = (now // self.stats_bucket_size) * self.stats_bucket_size + + def _fetch_data(txn): + + # Get the current token of the room + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + current_state_events = len(current_state_ids) + + membership_counts = self._get_user_counts_in_room_txn(txn, room_id) + + total_state_events = self._get_total_state_event_counts_txn( + txn, room_id + ) + + self._update_stats_txn( + txn, + "room", + room_id, + now, + { + "bucket_size": self.stats_bucket_size, + "current_state_events": current_state_events, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), + "state_events": total_state_events, + }, + ) + self._simple_insert_txn( + txn, + "room_stats_earliest_token", + {"room_id": room_id, "token": current_token}, + ) + + # We've finished a room. Delete it from the table. + self._simple_delete_one_txn( + txn, TEMP_TABLE + "_rooms", {"room_id": room_id}, + ) + + yield self.runInteraction("update_room_stats", _fetch_data) + + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_stats", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + defer.returnValue(processed_event_count) + + defer.returnValue(processed_event_count) + + def delete_all_stats(self): + """ + Delete all statistics records. + """ + + def _delete_all_stats_txn(txn): + txn.execute("DELETE FROM room_state") + txn.execute("DELETE FROM room_stats") + txn.execute("DELETE FROM room_stats_earliest_token") + txn.execute("DELETE FROM user_stats") + + return self.runInteraction("delete_all_stats", _delete_all_stats_txn) + + def get_stats_stream_pos(self): + return self._simple_select_one_onecol( + table="stats_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="stats_stream_pos", + ) + + def update_stats_stream_pos(self, stream_id): + return self._simple_update_one( + table="stats_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_stats_stream_pos", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + + # For whatever reason some of the fields may contain null bytes, which + # postgres isn't a fan of, so we replace those fields with null. + for col in ( + "join_rules", + "history_visibility", + "encryption", + "name", + "topic", + "avatar", + "canonical_alias" + ): + field = fields.get(col) + if field and "\0" in field: + fields[col] = None + + return self._simple_upsert( + table="room_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_deltas_for_room(self, room_id, start, size=100): + """ + Get statistics deltas for a given room. + + Args: + room_id (str) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS["room"] and "ts". + """ + return self._simple_select_list_paginate( + "room_stats", + {"room_id": room_id}, + "ts", + start, + size, + retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + order_direction="DESC", + ) + + def get_all_room_state(self): + return self._simple_select_list( + "room_state", None, retcols=("name", "topic", "canonical_alias") + ) + + @cached() + def get_earliest_token_for_room_stats(self, room_id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + return self._simple_select_one_onecol( + "room_stats_earliest_token", + {"room_id": room_id}, + retcol="token", + allow_none=True, + ) + + def update_stats(self, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert( + table=table, + keyvalues={id_col: stats_id, "ts": ts}, + values=fields, + desc="update_stats", + ) + + def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert_txn( + txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + ) + + def update_stats_delta(self, ts, stats_type, stats_id, field, value): + def _update_stats_delta(txn): + table, id_col = TYPE_TO_ROOM[stats_type] + + sql = ( + "SELECT * FROM %s" + " WHERE %s=? and ts=(" + " SELECT MAX(ts) FROM %s" + " WHERE %s=?" + ")" + ) % (table, id_col, table, id_col) + txn.execute(sql, (stats_id, stats_id)) + rows = self.cursor_to_dict(txn) + if len(rows) == 0: + # silently skip as we don't have anything to apply a delta to yet. + # this tries to minimise any race between the initial sync and + # subsequent deltas arriving. + return + + current_ts = ts + latest_ts = rows[0]["ts"] + if current_ts < latest_ts: + # This one is in the past, but we're just encountering it now. + # Mark it as part of the current bucket. + current_ts = latest_ts + elif ts != latest_ts: + # we have to copy our absolute counters over to the new entry. + values = { + key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] + } + values[id_col] = stats_id + values["ts"] = ts + values["bucket_size"] = self.stats_bucket_size + + self._simple_insert_txn(txn, table=table, values=values) + + # actually update the new value + if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: + self._simple_update_txn( + txn, + table=table, + keyvalues={id_col: stats_id, "ts": current_ts}, + updatevalues={field: value}, + ) + else: + sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % ( + table, + field, + field, + id_col, + ) + txn.execute(sql, (value, stats_id, current_ts)) + + return self.runInteraction("update_stats_delta", _update_stats_delta) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 9cd1e0f9f..529ad4ea7 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -64,59 +64,135 @@ _EventDictReturn = namedtuple( ) -def lower_bound(token, engine, inclusive=False): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, ) - return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) - -def upper_bound(token, engine, inclusive=True): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well - # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) >%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, ) - return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert(bound in (">", "<", ">=", "<=")) + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y 0, complexity) + + # Artificially raise the complexity + store = self.hs.get_datastore() + store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + + # Get the room complexity again -- make sure it's our artificial value + request, channel = self.make_request( + "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code) + complexity = channel.json_body["v1"] + self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 28e7e2741..7bb106b5f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -33,11 +33,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -46,21 +50,24 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but @@ -68,11 +75,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -81,25 +92,30 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) mock_send_transaction.reset_mock() # send the second RR - receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() mock_send_transaction.assert_not_called() @@ -111,18 +127,21 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['other_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5b2105bc7..917548bb3 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -115,11 +115,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): # We cheekily override the config to add custom alias creation rules config = {} config["alias_creation_rules"] = [ - { - "user_id": "*", - "alias": "#unofficial_*", - "action": "allow", - } + {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] config["room_list_publication_rules"] = [] @@ -162,9 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -179,10 +173,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -191,10 +182,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) @@ -202,9 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 1c49bbbc3..2e72a1dd2 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -36,7 +36,7 @@ room_keys = { "first_message_index": 1, "forwarded_count": 1, "is_verified": False, - "session_data": "SSBBTSBBIEZJU0gK" + "session_data": "SSBBTSBBIEZJU0gK", } } } @@ -47,15 +47,13 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer + self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, - handlers=None, - replication_layer=mock.Mock(), + self.addCleanup, handlers=None, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname @@ -88,67 +86,86 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # upload a new one... - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(res, "2") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "2", - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) @defer.inlineCallbacks def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - res = yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + res = yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) @defer.inlineCallbacks def test_update_missing_version(self): @@ -156,11 +173,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version(self.local_user, "1", { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1" - }) + yield self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -170,29 +191,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 400 if the version in the body is missing or doesn't match """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) @@ -223,10 +252,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can delete it @@ -255,16 +284,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = yield self.handler.get_room_keys(self.local_user, version) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest @@ -275,7 +302,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + yield self.handler.upload_room_keys( + self.local_user, "no_version", room_keys + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -285,10 +314,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None @@ -304,16 +333,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(version, "2") res = None @@ -327,10 +359,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -340,18 +372,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check getting room_keys for a given room res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org" + self.local_user, version, room_id="!abc:matrix.org" ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) self.assertDictEqual(res, room_keys) @@ -359,10 +386,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -378,7 +405,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "SSBBTSBBIEZJU0gK" + "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it @@ -387,8 +414,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # test that a session with a higher forwarded_count doesn't replace one @@ -399,8 +425,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # TODO: check edge cases as well as the common variations here @@ -409,56 +434,36 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") # check for bulk-delete yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys(self.local_user, version) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", + self.local_user, version, room_id="!abc:matrix.org" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 94c6080e3..f70c6e7d6 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, - federation_sender=Mock(), + "server", http_client=None, federation_sender=Mock() ) return hs @@ -457,7 +456,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test2 as online, test will be offline with a last_active of 0 self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -506,13 +505,13 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test as online self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) # Add servers to the room @@ -541,8 +540,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), - states=[expected_state] + destinations=set(("server2", "server3")), states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -565,7 +563,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): type=EventTypes.Member, sender=user_id, state_key=user_id, - content={"membership": Membership.JOIN} + content={"membership": Membership.JOIN}, ) prev_event_ids = self.get_success( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 017ea0385..5ffba2ca7 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -37,8 +37,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config = self.default_config("test") # some of the tests rely on us having a user consent version - hs_config.user_consent_version = "test_consent_version" - hs_config.max_mau_value = 50 + hs_config["user_consent"] = { + "version": "test_consent_version", + "template_dir": ".", + } + hs_config["max_mau_value"] = 50 + hs_config["limit_usage_by_mau"] = True hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) return hs @@ -224,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_not_support_user(self): res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) + + def test_invalid_user_id_length(self): + invalid_user_id = "x" * 256 + self.get_failure( + self.handler.register(localpart=invalid_user_id), + SynapseError + ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py new file mode 100644 index 000000000..249aba3d5 --- /dev/null +++ b/tests/handlers/test_stats.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class StatsRoomTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + + self.store = hs.get_datastore() + self.handler = self.hs.get_stats_handler() + + def _add_background_updates(self): + """ + Add the background updates we need to run. + """ + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + "depends_on": "populate_stats_createtables", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + def test_initial_room(self): + """ + The background updates will build the table from scratch. + """ + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Disable stats + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Stats disabled, shouldn't have done anything + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Enable stats + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + r = self.get_success(self.store.get_all_room_state()) + + self.assertEqual(len(r), 1) + self.assertEqual(r[0]["topic"], "foo") + + def test_initial_earliest_token(self): + """ + Ingestion via notify_new_event will ignore tokens that the background + update have already processed. + """ + self.reactor.advance(86401) + + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + u2 = self.register_user("u2", "pass") + u2_token = self.login("u2", "pass") + + u3 = self.register_user("u3", "pass") + u3_token = self.login("u3", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Begin the ingestion by creating the temp tables. This will also store + # the position that the deltas should begin at, once they take over. + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + self.store._all_done = False + self.get_success(self.store.update_stats_stream_pos(None)) + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Now, before the table is actually ingested, add some more events. + self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) + self.helper.join(room=room_1, user=u2, tok=u2_token) + + # Now do the initial ingestion. + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + self.store._all_done = False + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + self.reactor.advance(86401) + + # Now add some more events, triggering ingestion. Because of the stream + # position being set to before the events sent in the middle, a simpler + # implementation would reprocess those events, and say there were four + # users, not three. + self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) + self.helper.join(room=room_1, user=u3, tok=u3_token) + + # Get the deltas! There should be two -- day 1, and day 2. + r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + + # The oldest has 2 joined members + self.assertEqual(r[-1]["joined_members"], 2) + + # The newest has 3 + self.assertEqual(r[0]["joined_members"], 3) + + def test_incorrect_state_transition(self): + """ + If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to + (JOIN, INVITE, LEAVE, BAN), an error is raised. + """ + events = { + "a1": {"membership": Membership.LEAVE}, + "a2": {"membership": "not a real thing"}, + } + + def get_event(event_id): + m = Mock() + m.content = events[event_id] + d = defer.Deferred() + self.reactor.callLater(0.0, d.callback, m) + return d + + def get_received_ts(event_id): + return defer.succeed(1) + + self.store.get_received_ts = get_received_ts + self.store.get_event = get_event + + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a1", + "prev_event_id": "a2", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid prev_membership" + ) + + # And the other way... + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a2", + "prev_event_id": "a1", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid membership" + ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5a0b6c201..cb8b4d291 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,20 +64,22 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) hs = self.setup_test_homeserver( - datastore=(Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - )), + datastore=( + Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + ] + ) + ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, @@ -87,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): # the tests assume that we are starting at unix time 1000 - reactor.pump((1000, )) + reactor.pump((1000,)) mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -114,6 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): @@ -123,6 +126,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def get_current_users_in_room(room_id): return set(str(u) for u in self.room_members) + hs.get_state_handler().get_current_users_in_room = get_current_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( @@ -141,21 +145,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -170,12 +169,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) + ) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -216,14 +214,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -247,14 +241,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", @@ -274,18 +268,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) def test_typing_timeout(self): @@ -293,22 +279,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -320,45 +301,30 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - self.reactor.pump([16, ]) + self.reactor.pump([16]) - self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=1 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) # SYN-230 - see if we can still set after timeout - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f1d0aa42b..9021e647f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,8 +14,9 @@ # limitations under the License. from mock import Mock +import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo @@ -29,14 +30,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, ] def make_homeserver(self, reactor, clock): config = self.default_config() - config.update_user_directory = True + config["update_user_directory"] = True return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): @@ -327,12 +328,12 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): user_directory.register_servlets, room.register_servlets, login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, ] def make_homeserver(self, reactor, clock): config = self.default_config() - config.update_user_directory = True + config["update_user_directory"] = True hs = self.setup_test_homeserver(config=config) self.config = hs.config @@ -351,9 +352,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Assert user directory is not empty request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -362,9 +361,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index ee8010f59..851fc0eb3 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -24,14 +24,12 @@ def get_test_cert_file(): # # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \ # -nodes -subj '/CN=testserv' - return os.path.join( - os.path.dirname(__file__), - 'server.pem', - ) + return os.path.join(os.path.dirname(__file__), 'server.pem') class ServerTLSContext(object): """A TLS Context which presents our test cert.""" + def __init__(self): self.filename = get_test_cert_file() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index dcf184d3c..ed0ca079d 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -39,6 +39,7 @@ from synapse.util.logcontext import LoggingContext from tests.http import ServerTLSContext from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase +from tests.utils import default_config logger = logging.getLogger(__name__) @@ -53,7 +54,9 @@ class MatrixFederationAgentTests(TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(None), + tls_client_options_factory=ClientTLSOptionsFactory( + default_config("test", parse=True) + ), _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, @@ -78,12 +81,12 @@ class MatrixFederationAgentTests(TestCase): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol), + FakeTransport(server_tls_protocol, self.reactor, client_protocol) ) # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol), + FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) # give the reactor a pump to get the TLS juices flowing. @@ -124,7 +127,7 @@ class MatrixFederationAgentTests(TestCase): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={}, + self, client_factory, expected_sni, content, response_headers={} ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -138,8 +141,7 @@ class MatrixFederationAgentTests(TestCase): """ # make the connection for .well-known well_known_server = self._make_connection( - client_factory, - expected_sni=expected_sni, + client_factory, expected_sni=expected_sni ) # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) @@ -153,17 +155,14 @@ class MatrixFederationAgentTests(TestCase): """ self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # send back a response for k, v in headers.items(): request.setHeader(k, v) request.write(content) request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) def test_get(self): """ @@ -183,18 +182,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b"testserv", - ) + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] ) content = request.content.read() self.assertEqual(content, b'') @@ -243,19 +238,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'1.2.3.4'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) # finish the request request.finish() @@ -284,19 +273,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) # finish the request request.finish() @@ -325,19 +308,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 80) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]:80'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) # finish the request request.finish() @@ -376,7 +353,7 @@ class MatrixFederationAgentTests(TestCase): # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -386,19 +363,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -426,13 +397,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -443,8 +415,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -452,8 +423,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -489,8 +459,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) redirect_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) # send a 302 redirect @@ -499,7 +468,7 @@ class MatrixFederationAgentTests(TestCase): request.redirect(b'https://testserv/even_better_known') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # now there should be another connection clients = self.reactor.tcpClients @@ -509,8 +478,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) well_known_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) self.assertEqual(len(well_known_server.requests), 1, "No request after 302") @@ -520,11 +488,11 @@ class MatrixFederationAgentTests(TestCase): request.write(b'{ "m.server": "target-server" }') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -535,8 +503,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -544,8 +511,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -584,12 +550,12 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON', + client_factory, expected_sni=b"testserv", content=b'NOT JSON' ) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -599,19 +565,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -634,7 +594,7 @@ class MatrixFederationAgentTests(TestCase): # the request for a .well-known will have failed with a DNS lookup error. self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # Make sure treq is trying to connect @@ -645,19 +605,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -684,17 +638,18 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443), + Server(host=b"srvtarget", port=8443) ] self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target of the SRV record @@ -705,8 +660,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -714,8 +668,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -756,7 +709,7 @@ class MatrixFederationAgentTests(TestCase): # now there should have been a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # We should fall back to port 8448 @@ -768,8 +721,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -777,8 +729,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -800,7 +751,7 @@ class MatrixFederationAgentTests(TestCase): self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # Make sure treq is trying to connect @@ -812,8 +763,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -821,8 +771,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -896,67 +845,70 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}), - ), 100, + Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + ), + 100, ) # missing value - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=, bar']}), - )) + self.assertIsNone( + _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + ) # hackernews: bogus due to semicolon - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}), - )) + self.assertIsNone( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'private; max-age=0']}) + ) + ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}), - ), 0, + Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + ), + 0, ) # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}), - ), 0, + Headers({b'cache-control': [b'private, max-age=0']}) + ), + 0, ) def test_expires(self): self.assertEqual( _cache_period_from_headers( Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), - time_now=lambda: 1548833700 - ), 33, + time_now=lambda: 1548833700, + ), + 33, ) # cache-control overrides expires self.assertEqual( _cache_period_from_headers( - Headers({ - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'] - }), - time_now=lambda: 1548833700 - ), 10, + Headers( + { + b'cache-control': [b'max-age=10'], + b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + } + ), + time_now=lambda: 1548833700, + ), + 10, ) # invalid expires means immediate expiry - self.assertEqual( - _cache_period_from_headers( - Headers({b'Expires': [b'0']}), - ), 0, - ) + self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) def _check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _build_test_server(): @@ -972,7 +924,7 @@ def _build_test_server(): server_factory.log = _log_request server_tls_factory = TLSMemoryBIOFactory( - ServerTLSContext(), isClient=False, wrappedFactory=server_factory, + ServerTLSContext(), isClient=False, wrappedFactory=server_factory ) return server_tls_factory.buildProtocol(None) @@ -986,6 +938,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) class TrustingTLSPolicyForHTTPS(object): """An IPolicyForHTTPS which doesn't do any certificate verification""" + def creatorForNetloc(self, hostname, port): certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index a872e2441..034c0db8d 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -68,9 +68,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - result_deferred.callback( - ([answer_srv], None, None) - ) + result_deferred.callback(([answer_srv], None, None)) servers = self.successResultOf(test_d) @@ -112,7 +110,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver( - dns_client=dns_client_mock, cache=cache, get_time=clock.time, + dns_client=dns_client_mock, cache=cache, get_time=clock.time ) servers = yield resolver.resolve_service(service_name) @@ -168,11 +166,13 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # returning a single "." should make the lookup fail with a ConenctError - lookup_deferred.callback(( - [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], - None, - None, - )) + lookup_deferred.callback( + ( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + ) + ) self.failureResultOf(resolve_d, ConnectError) @@ -191,14 +191,16 @@ class SrvResolverTestCase(unittest.TestCase): resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) - lookup_deferred.callback(( - [ - dns.RRHeader(type=dns.A, payload=dns.Record_A()), - dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), - ], - None, - None, - )) + lookup_deferred.callback( + ( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + ) + ) servers = self.successResultOf(resolve_d) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index cd8e086f8..ee767f3a5 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -15,6 +15,8 @@ from mock import Mock +from netaddr import IPSet + from twisted.internet import defer from twisted.internet.defer import TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError @@ -36,9 +38,7 @@ from tests.unittest import HomeserverTestCase def check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): @@ -54,6 +54,7 @@ class FederationClientTests(HomeserverTestCase): """ happy-path test of a GET request """ + @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: @@ -175,8 +176,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance( - f.value.inner_exception, - (ConnectingCancelledError, TimeoutError), + f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) def test_client_connect_no_response(self): @@ -211,14 +211,81 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived) + def test_client_ip_range_blacklist(self): + """Ensure that Synapse does not try to connect to blacklisted IPs""" + + # Set up the ip_range blacklist + self.hs.config.federation_ip_range_blacklist = IPSet([ + "127.0.0.0/8", + "fe80::/64", + ]) + self.reactor.lookups["internal"] = "127.0.0.1" + self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" + self.reactor.lookups["fine"] = "10.20.30.40" + cl = MatrixFederationHttpClient(self.hs, None) + + # Try making a GET request to a blacklisted IPv4 address + # ------------------------------------------------------ + # Make the request + d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + + # Nothing happened yet + self.assertNoResult(d) + + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, RequestSendFailed) + self.assertIsInstance(f.value.inner_exception, DNSLookupError) + + # Try making a POST request to a blacklisted IPv6 address + # ------------------------------------------------------- + # Make the request + d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + + # Nothing has happened yet + self.assertNoResult(d) + + # Move the reactor forwards + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + # Check that it was due to a blacklisted DNS lookup + f = self.failureResultOf(d, RequestSendFailed) + self.assertIsInstance(f.value.inner_exception, DNSLookupError) + + # Try making a GET request to a non-blacklisted IPv4 address + # ---------------------------------------------------------- + # Make the request + d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + + # Nothing has happened yet + self.assertNoResult(d) + + # Move the reactor forwards + self.pump(1) + + # Check that it was able to resolve the address + clients = self.reactor.tcpClients + self.assertNotEqual(len(clients), 0) + + # Connection will still fail as this IP address does not resolve to anything + f = self.failureResultOf(d, RequestSendFailed) + self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError) + def test_client_gets_headers(self): """ Once the client gets the headers, _request returns successfully. """ request = MatrixFederationRequest( - method="GET", - destination="testserv:8008", - path="foo/bar", + method="GET", destination="testserv:8008", path="foo/bar" ) d = self.cl._send_request(request, timeout=10000) @@ -258,8 +325,10 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n") + ( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" + ) ) # Push by enough to time it out @@ -274,9 +343,7 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -329,9 +396,7 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -368,10 +433,7 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json( - "testserv:8008", "foo/bar", timeout=10000, - data={"a": "b"} - ) + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) self.pump() diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 0f613945c..ee0add345 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -45,7 +45,9 @@ def do_patch(): except Exception: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s on exception" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -54,7 +56,9 @@ def do_patch(): if not isinstance(res, Deferred) or res.called: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -66,9 +70,7 @@ def do_patch(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % ( - f, LoggingContext.current_context(), start_context, - ) + ) % (f, LoggingContext.current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) @@ -76,7 +78,9 @@ def do_patch(): if LoggingContext.current_context() != start_context: err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", - f, start_context, LoggingContext.current_context(), + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index be3fed8de..9cdde1a9b 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -19,7 +19,8 @@ import pkg_resources from twisted.internet.defer import Deferred -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase @@ -33,7 +34,7 @@ class EmailPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -51,22 +52,26 @@ class EmailPusherTests(HomeserverTestCase): return d config = self.default_config() - config.email_enable_notifs = True - config.start_pushers = True - - config.email_template_dir = os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') - ) - config.email_notif_template_html = "notif_mail.html" - config.email_notif_template_text = "notif_mail.txt" - config.email_smtp_host = "127.0.0.1" - config.email_smtp_port = 20 - config.require_transport_security = False - config.email_smtp_user = None - config.email_smtp_pass = None - config.email_app_name = "Matrix" - config.email_notif_from = "test@example.com" - config.email_riot_base_url = None + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "app_name": "Matrix", + "notif_from": "test@example.com", + "riot_base_url": None, + } + config["public_baseurl"] = "aaa" + config["start_pushers"] = True hs = self.setup_test_homeserver(config=config, sendmail=sendmail) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 6dc45e850..aba618b2b 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -17,7 +17,8 @@ from mock import Mock from twisted.internet.defer import Deferred -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.util.logcontext import make_deferred_yieldable from tests.unittest import HomeserverTestCase @@ -32,7 +33,7 @@ class HTTPPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -53,7 +54,7 @@ class HTTPPusherTests(HomeserverTestCase): m.post_json_get_json = post_json_get_json config = self.default_config() - config.start_pushers = True + config["start_pushers"] = True hs = self.setup_test_homeserver(config=config, simple_http_client=m) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 1f72a2a04..104349cdb 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -74,21 +74,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( master_result, expected_result, - "Expected master result to be %r but was %r" % ( - expected_result, master_result - ), + "Expected master result to be %r but was %r" + % (expected_result, master_result), ) self.assertEqual( slaved_result, expected_result, - "Expected slave result to be %r but was %r" % ( - expected_result, slaved_result - ), + "Expected slave result to be %r but was %r" + % (expected_result, slaved_result), ) self.assertEqual( master_result, slaved_result, - "Slave result %r does not match master result %r" % ( - slaved_result, master_result - ), + "Slave result %r does not match master result %r" + % (slaved_result, master_result), ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 65ecff3bd..a368117b4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,10 +234,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([ - (j2, j2ctx), - (msg, msgctx), - ])) + self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) self.replicate() event_source = RoomEventSource(self.hs) @@ -257,15 +254,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # # First, we get a list of the rooms we are joined to joined_rooms = self.get_success( - self.slaved_store.get_rooms_for_user_with_stream_ordering( - USER_ID_2, - ), + self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) ) # Then, we get a list of the events since the last sync membership_changes = self.get_success( self.slaved_store.get_membership_changes_for_user( - USER_ID_2, prev_token, current_token, + USER_ID_2, prev_token, current_token ) ) @@ -298,9 +293,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.master_store.persist_events([(event, context)], backfilled=True) ) else: - self.get_success( - self.master_store.persist_event(event, context) - ) + self.get_success(self.master_store.persist_event(event, context)) return event @@ -359,9 +352,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context( - event - )) + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 38b368a97..ce3835ae6 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -22,6 +22,7 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) @@ -52,6 +53,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): self.received_rdata_rows = [] @@ -69,6 +71,4 @@ class TestReplicationClientHandler(object): def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append( - (stream_name, token, r) - ) + self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/rest/admin/__init__.py b/tests/rest/admin/__init__.py new file mode 100644 index 000000000..1453d0457 --- /dev/null +++ b/tests/rest/admin/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/admin/test_admin.py similarity index 84% rename from tests/rest/client/v1/test_admin.py rename to tests/rest/admin/test_admin.py index c00ef21d7..e5fc2fcd1 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -19,50 +19,37 @@ import json from mock import Mock +import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, events, login, room +from synapse.http.server import JsonResource +from synapse.rest.admin import VersionServlet +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v2_alpha import groups from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): + url = '/_synapse/admin/v1/server_version' - servlets = [ - admin.register_servlets, - login.register_servlets, - ] - - url = '/_matrix/client/r0/admin/server_version' + def create_test_json_resource(self): + resource = JsonResource(self.hs) + VersionServlet(self.hs).register(resource) + return resource def test_version_string(self): - self.register_user("admin", "pass", admin=True) - self.admin_token = self.login("admin", "pass") - - request, channel = self.make_request("GET", self.url, - access_token=self.admin_token) + request, channel = self.make_request("GET", self.url, shorthand=False) self.render(request) - self.assertEqual(200, int(channel.result["code"]), - msg=channel.result["body"]) - self.assertEqual({'server_version', 'python_version'}, - set(channel.json_body.keys())) - - def test_inaccessible_to_non_admins(self): - self.register_user("unprivileged-user", "pass", admin=False) - user_token = self.login("unprivileged-user", "pass") - - request, channel = self.make_request("GET", self.url, - access_token=user_token) - self.render(request) - - self.assertEqual(403, int(channel.result['code']), - msg=channel.result['body']) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + {'server_version', 'python_version'}, set(channel.json_body.keys()) + ) class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): @@ -213,9 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin" - ) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -343,11 +328,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps({ - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid"} + body = json.dumps( + { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } ) request, channel = self.make_request("POST", self.url, body.encode('utf8')) self.render(request) @@ -358,7 +345,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, events.register_servlets, room.register_servlets, @@ -370,9 +357,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): hs.config.user_consent_version = "1" consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = ( - "http://example.com" - ) + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder self.store = hs.get_datastore() @@ -384,9 +369,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.other_user_token = self.login("user", "pass") # Mark the admin user as having consented - self.get_success( - self.store.user_set_consent_version(self.admin_user, "1"), - ) + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not @@ -398,9 +381,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) # Assert one user in room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([self.other_user], users_in_room) # Enable require consent to send events @@ -408,8 +389,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - room_id, - body="foo", tok=self.other_user_token, expect_code=403, + room_id, body="foo", tok=self.other_user_token, expect_code=403 ) # Test that the admin can still send shutdown @@ -425,12 +405,9 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Assert there is now no longer anyone in the room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) - @unittest.DEBUG def test_shutdown_room_block_peek(self): """Test that a world_readable room can no longer be peeked into after it has been shut down. @@ -472,30 +449,26 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) class DeleteGroupTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, groups.register_servlets, ] @@ -515,15 +488,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", "/create_group".encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -533,27 +502,17 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.admin_user_tok, - content={} + "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.other_user_token, - content={} + "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -565,15 +524,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", url.encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 self._check_group(group_id, expect_code=404) @@ -589,28 +544,22 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", - "/joined_groups".encode('ascii'), - access_token=access_token, + "GET", "/joined_groups".encode('ascii'), access_token=access_token ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 4294bbec2..88f8f1abd 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -15,8 +15,9 @@ import os +import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest @@ -31,7 +32,7 @@ except Exception: class ConsentResourceTestCase(unittest.HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -41,15 +42,18 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.user_consent_version = "1" - config.public_baseurl = "" - config.form_secret = "123abc" + config["public_baseurl"] = "aaaa" + config["form_secret"] = "123abc" # Make some temporary templates... temp_consent_path = self.mktemp() os.mkdir(temp_consent_path) os.mkdir(os.path.join(temp_consent_path, 'en')) - config.user_consent_template_dir = os.path.abspath(temp_consent_path) + + config["user_consent"] = { + "version": "1", + "template_dir": os.path.abspath(temp_consent_path), + } with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: f.write("{{version}},{{has_consented}}") diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca63b2e6e..68949307d 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -15,7 +15,8 @@ import json -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -23,7 +24,7 @@ from tests import unittest class IdentityTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -31,7 +32,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.enable_3pid_lookup = False + config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -43,7 +44,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok, + b"POST", "/createRoom", b"{}", access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -55,11 +56,9 @@ class IdentityTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ( - "/rooms/%s/invite" % (room_id) - ).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok, + b"POST", request_url, request_data, access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py new file mode 100644 index 000000000..633b7dbda --- /dev/null +++ b/tests/rest/client/v1/test_directory.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.rest import admin +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias +from synapse.util.stringutils import random_string + +from tests import unittest + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_membership_for_aliases"] = True + + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.user = self.register_user("user", "test") + self.user_tok = self.login("user", "test") + + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) + + def test_directory_endpoint_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_directory(403) + + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + + def test_directory_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(400, alias_length=256) + + def test_state_event_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + + def test_directory_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(200) + + def test_room_creation_too_long(self): + url = "/_matrix/client/r0/createRoom" + + # We use deliberately a localpart under the length threshold so + # that we can make sure that the check is done on the whole alias. + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) + + def test_room_creation(self): + url = "/_matrix/client/r0/createRoom" + + # Check with an alias of allowed length. There should already be + # a test that ensures it works in test_register.py, but let's be + # as cautious as possible here. + data = {"room_alias_name": random_string(5)} + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def set_alias_via_directory(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def random_alias(self, length): + return RoomAlias(random_string(length), self.hs.hostname).to_string() + + def ensure_user_left_room(self): + self.ensure_membership("leave") + + def ensure_user_joined_room(self): + self.ensure_membership("join") + + def ensure_membership(self, membership): + try: + if membership == "leave": + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) + if membership == "join": + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 36d854727..f340b7e85 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -17,7 +17,8 @@ from mock import Mock, NonCallableMock -from synapse.rest.client.v1 import admin, events, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import events, login, room from tests import unittest @@ -28,16 +29,16 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): servlets = [ events.register_servlets, room.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] def make_homeserver(self, reactor, clock): config = self.default_config() - config.enable_registration_captcha = False - config.enable_registration = True - config.auto_join_rooms = [] + config["enable_registration_captcha"] = False + config["enable_registration"] = True + config["auto_join_rooms"] = [] hs = self.setup_test_homeserver( config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 86312f109..0397f91a9 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,6 +1,7 @@ import json -from synapse.rest.client.v1 import admin, login +import synapse.rest.admin +from synapse.rest.client.v1 import login from tests import unittest @@ -10,7 +11,7 @@ LOGIN_URL = b"/_matrix/client/r0/login" class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] @@ -36,10 +37,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -56,14 +54,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -81,10 +76,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -101,14 +93,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -126,10 +115,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) @@ -146,14 +132,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 1eab9c3bd..72c7ed93c 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,24 +14,30 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import json + from mock import Mock from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError -from synapse.rest.client.v1 import profile +from synapse.rest import admin +from synapse.rest.client.v1 import login, profile, room from tests import unittest from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/api/v1" +PATH_PREFIX = "/_matrix/client/r0" -class ProfileTestCase(unittest.TestCase): - """ Tests profile management. """ +class MockHandlerProfileTestCase(unittest.TestCase): + """ Tests rest layer of profile management. + + Todo: move these into ProfileTestCase + """ @defer.inlineCallbacks def setUp(self): @@ -42,6 +48,7 @@ class ProfileTestCase(unittest.TestCase): "set_displayname", "get_avatar_url", "set_avatar_url", + "check_profile_query_allowed", ] ) @@ -155,3 +162,130 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") + + +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + + def test_set_displayname(self): + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test"}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test" * 100}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "owner") + + def get_displayname(self): + request, channel = self.make_request( + "GET", + "/profile/%s/displayname" % (self.owner, ), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + +class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["require_auth_for_profile_requests"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User owning the requested profile. + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.profile_url = "/profile/%s" % (self.owner) + + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) + + def test_no_auth(self): + self.try_fetch_profile(401) + + def test_not_in_shared_room(self): + self.ensure_requester_left_room() + + self.try_fetch_profile(403, access_token=self.requester_tok) + + def test_in_shared_room(self): + self.ensure_requester_left_room() + + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) + + self.try_fetch_profile(200, self.requester_tok) + + def try_fetch_profile(self, expected_code, access_token=None): + self.request_profile(expected_code, access_token=access_token) + + self.request_profile( + expected_code, url_suffix="/displayname", access_token=access_token + ) + + self.request_profile( + expected_code, url_suffix="/avatar_url", access_token=access_token + ) + + def request_profile(self, expected_code, url_suffix="", access_token=None): + request, channel = self.make_request( + "GET", self.profile_url + url_suffix, access_token=access_token + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def ensure_requester_left_room(self): + try: + self.helper.leave( + room=self.room_id, user=self.requester, tok=self.requester_tok + ) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 015c14424..5f75ad757 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,8 +23,9 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer +import synapse.rest.admin from synapse.api.constants import Membership -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -803,7 +805,7 @@ class RoomMessageListTestCase(RoomBase): class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -903,3 +905,102 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.assertEqual( context["profile_info"][self.other_user_id]["displayname"], "otheruser" ) + + +class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/publicRooms" + + config = self.default_config() + config["restrict_public_rooms_to_local_users"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_restricted_no_auth(self): + request, channel = self.make_request("GET", self.url) + self.render(request) + self.assertEqual(channel.code, 401, channel.result) + + def test_restricted_auth(self): + self.register_user("user", "pass") + tok = self.login("user", "pass") + + request, channel = self.make_request("GET", self.url, access_token=tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + +class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["allow_per_room_profiles"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + + # Set a profile for the test user + self.displayname = "test user" + data = { + "displayname": self.displayname, + } + request_data = json.dumps(data) + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), + request_data, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_per_room_profile_forbidden(self): + data = { + "membership": "join", + "displayname": "other test user" + } + request_data = json.dumps(data) + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + self.room_id, self.user_id, + ), + request_data, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + res_displayname = channel.json_body["content"]["displayname"] + self.assertEqual(res_displayname, self.displayname, channel.result) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 05b0143c4..f7133fc12 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -127,3 +127,20 @@ class RestHelper(object): ) return channel.json_body + + def send_state(self, room_id, event_type, body, tok, expect_code=200): + path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8') + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 7fa120a10..b9ef46e8f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -16,8 +16,8 @@ from twisted.internet.defer import succeed +import synapse.rest.admin from synapse.api.constants import LoginType -from synapse.rest.client.v1 import admin from synapse.rest.client.v2_alpha import auth, register from tests import unittest @@ -27,7 +27,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -36,9 +36,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): config = self.default_config() - config.enable_registration_captcha = True - config.recaptcha_public_key = "brokencake" - config.registrations_require_3pid = [] + config["enable_registration_captcha"] = True + config["recaptcha_public_key"] = "brokencake" + config["registrations_require_3pid"] = [] hs = self.setup_test_homeserver(config=config) return hs @@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(len(self.recaptcha_attempts), 1) self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a") - # Now we have fufilled the recaptcha fallback step, we can then send a + # also complete the dummy auth + request, channel = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) + self.render(request) + + # Now we should have fufilled a complete auth flow, including + # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. request, channel = self.make_request( "POST", "register", {"auth": {"session": session}} diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index bbfc77e82..bce5b0cf4 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import admin, login +import synapse.rest.admin +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest @@ -23,7 +23,7 @@ from tests import unittest class CapabilitiesTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, login.register_servlets, ] @@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/capabilities" hs = self.setup_test_homeserver() self.store = hs.get_datastore() + self.config = hs.config return hs def test_check_auth_required(self): @@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) for room_version in capabilities['m.room_versions']['available'].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) + self.assertEqual( - DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default'] + self.config.default_room_version.identifier, + capabilities['m.room_versions']['default'], ) def test_get_change_password_capabilities(self): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3d4466748..0cb6a363d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,13 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import datetime import json import os import pkg_resources +import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import admin, login +from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import account_validity, register, sync from tests import unittest @@ -40,11 +58,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): as_token = "i_am_an_app_service" appservice = ApplicationService( - as_token, self.hs.config.server_name, + as_token, + self.hs.config.server_name, id="1234", - namespaces={ - "users": [{"regex": r"@as_user.*", "exclusive": True}], - }, + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, ) self.hs.get_datastore().services_cache.append(appservice) @@ -56,10 +73,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): @@ -127,10 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) - det_data = { - "home_server": self.hs.hostname, - "device_id": "guest_device", - } + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @@ -158,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -186,7 +197,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -198,7 +209,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ register.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, account_validity.register_servlets, @@ -207,9 +218,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() # Test for account expiring after a week. - config.enable_registration = True - config.account_validity.enabled = True - config.account_validity.period = 604800000 # Time in ms for 1 week + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -220,23 +233,19 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) def test_manual_renewal(self): @@ -252,21 +261,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_matrix/client/unstable/admin/account_validity/validity" - params = { - "user_id": user_id, - } + params = {"user_id": user_id} request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -285,20 +290,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -307,7 +310,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ register.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, account_validity.register_servlets, @@ -315,14 +318,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + # Test for account expiring after a week and renewal emails being sent 2 # days before expiry. - config.enable_registration = True - config.account_validity.enabled = True - config.account_validity.renew_by_email_enabled = True - config.account_validity.period = 604800000 # Time in ms for 1 week - config.account_validity.renew_at = 172800000 # Time in ms for 2 days - config.account_validity.renew_email_subject = "Renew your account" + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + } # Email config. self.email_attempts = [] @@ -331,17 +337,23 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts.append((args, kwargs)) return - config.email_template_dir = os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') - ) - config.email_expiry_template_html = "notice_expiry.html" - config.email_expiry_template_text = "notice_expiry.txt" - config.email_smtp_host = "127.0.0.1" - config.email_smtp_port = 20 - config.require_transport_security = False - config.email_smtp_user = None - config.email_smtp_pass = None - config.email_notif_from = "test@example.com" + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) @@ -357,10 +369,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) # Move 6 days forward. This should trigger a renewal email to be sent. self.reactor.advance(datetime.timedelta(days=6).total_seconds()) @@ -378,9 +395,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -392,16 +407,65 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) request, channel = self.make_request( - b"POST", "/_matrix/client/unstable/account_validity/send_mail", + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + self.max_delta = self.validity_period * 10. / 100. + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = { + "enabled": False, + } + + self.hs = self.setup_test_homeserver(config=config) + self.hs.config.account_validity.period = self.validity_period + + self.store = self.hs.get_datastore() + + return self.hs + + def test_background_job(self): + """ + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. + """ + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta + + now_ms = self.hs.clock.time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py new file mode 100644 index 000000000..43b3049da --- /dev/null +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -0,0 +1,564 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json + +import six + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import register, relations + +from tests import unittest + + +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + relations.register_servlets, + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.user_id, self.user_token = self._create_user("alice") + self.user2_id, self.user2_token = self._create_user("bob") + + self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) + self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) + res = self.helper.send(self.room, body="Hi!", tok=self.user_token) + self.parent_id = res["event_id"] + + def test_send_relation(self): + """Tests that sending a relation using the new /send_relation works + creates the right shape of event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") + self.assertEquals(200, channel.code, channel.json_body) + + event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, event_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assert_dict( + { + "type": "m.reaction", + "sender": self.user_id, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "key": u"👍", + "rel_type": RelationTypes.ANNOTATION, + } + }, + }, + channel.json_body, + ) + + def test_deny_membership(self): + """Test that we deny relations on membership events + """ + channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) + self.assertEquals(400, channel.code, channel.json_body) + + def test_deny_double_react(self): + """Test that we deny relations on membership events + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(400, channel.code, channel.json_body) + + def test_basic_paginate_relations(self): + """Tests that calling pagination API corectly the latest relations. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + channel.json_body["chunk"][0], + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), six.string_types, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly. + """ + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + + idx += 1 + idx %= len(access_tokens) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group. + """ + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=u"👍", + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact one of the 'a' reactions + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations. + """ + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + + def test_edit(self): + """Test that a simple edit works. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + ) + + def _send_relation( + self, relation_type, event_type, key=None, content={}, access_token=None + ): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + key (str|None): The aggregation key used for m.annotation relation + type. + content(dict|None): The content of the created event. + access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` + + Returns: + FakeChannel + """ + if not access_token: + access_token = self.user_token + + query = "" + if key: + query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + + request, channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, self.parent_id, relation_type, event_type, query), + json.dumps(content).encode("utf-8"), + access_token=access_token, + ) + self.render(request) + return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 99b716f00..71895094b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,7 +15,8 @@ from mock import Mock -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -72,7 +73,7 @@ class FilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, sync.register_servlets, diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index af8f74eb4..00688a732 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b'inline; filename="aze%20rty"': u"aze%20rty", b'inline; filename="aze\"rty"': u'aze"rty', b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers( - { - b'Content-Disposition': [hdr], - }, - ) + res = get_filename_from_headers({b'Content-Disposition': [hdr]}) self.assertEqual( - res, expected, - "expected output for %s to be %s but was %s" % ( - hdr, expected, res, - ) + res, + expected, + "expected output for %s to be %s but was %s" % (hdr, expected, res), ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ad5e9a612..1069a4414 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -25,13 +25,11 @@ from six.moves.urllib import parse from twisted.internet import defer, reactor from twisted.internet.defer import Deferred -from synapse.config.repository import MediaStorageProviderConfig from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.util.logcontext import make_deferred_yieldable -from synapse.util.module_loader import load_module from tests import unittest @@ -120,12 +118,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): client.get_file = get_file self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) config = self.default_config() - config.media_store_path = self.storage_path - config.thumbnail_requirements = {} - config.max_image_pixels = 2000000 + config["media_store_path"] = self.media_store_path + config["thumbnail_requirements"] = {} + config["max_image_pixels"] = 2000000 provider_config = { "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", @@ -134,12 +134,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "store_remote": True, "config": {"directory": self.storage_path}, } - - loaded = list(load_module(provider_config)) + [ - MediaStorageProviderConfig(False, False, False) - ] - - config.media_storage_providers = [loaded] + config["media_storage_providers"] = [provider_config] hs = self.setup_test_homeserver(config=config, http_client=client) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 650ce95a6..1ab0f7293 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -16,7 +16,6 @@ import os import attr -from netaddr import IPSet from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -25,9 +24,6 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web._newclient import ResponseDone -from synapse.config.repository import MediaStorageProviderConfig -from synapse.util.module_loader import load_module - from tests import unittest from tests.server import FakeTransport @@ -67,23 +63,23 @@ class URLPreviewTests(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - self.storage_path = self.mktemp() - os.mkdir(self.storage_path) - config = self.default_config() - config.url_preview_enabled = True - config.max_spider_size = 9999999 - config.url_preview_ip_range_blacklist = IPSet( - ( - "192.168.1.1", - "1.0.0.0/8", - "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "2001:800::/21", - ) + config["url_preview_enabled"] = True + config["max_spider_size"] = 9999999 + config["url_preview_ip_range_blacklist"] = ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", ) - config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) - config.url_preview_url_blacklist = [] - config.media_store_path = self.storage_path + config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) + config["url_preview_url_blacklist"] = [] + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path provider_config = { "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", @@ -93,11 +89,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "config": {"directory": self.storage_path}, } - loaded = list(load_module(provider_config)) + [ - MediaStorageProviderConfig(False, False, False) - ] - - config.media_storage_providers = [loaded] + config["media_storage_providers"] = [provider_config] hs = self.setup_test_homeserver(config=config) @@ -297,12 +289,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -318,12 +310,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -339,7 +331,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, { @@ -347,6 +338,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): 'error': 'IP address blocked by IP blacklist entry', }, ) + self.assertEqual(channel.code, 403) def test_blacklisted_ip_range_direct(self): """ @@ -414,12 +406,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -439,12 +431,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -460,11 +452,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 8d8f03e00..b090bb974 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -31,27 +31,24 @@ class WellKnownTests(unittest.HomeserverTestCase): self.hs.config.default_identity_server = "https://testis" request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) self.assertEqual(request.code, 200) self.assertEqual( - channel.json_body, { + channel.json_body, + { "m.homeserver": {"base_url": "https://tesths"}, "m.identity_server": {"base_url": "https://testis"}, - } + }, ) def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) diff --git a/tests/server.py b/tests/server.py index 8f89f4a83..c15a47f2a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -182,7 +182,8 @@ def make_request( if federation_auth_origin is not None: req.requestHeaders.addRawHeader( - b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + b"Authorization", + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,), ) if content: @@ -226,6 +227,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ def __init__(self): + self.threadpool = ThreadPool(self) + self._udp = [] lookups = self.lookups = {} @@ -233,7 +236,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): class FakeResolver(object): def getHostByName(self, name, timeout=None): if name not in lookups: - return fail(DNSLookupError("OH NO: unknown %s" % (name, ))) + return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) @@ -254,6 +257,37 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.callLater(0, d.callback, True) return d + def getThreadPool(self): + return self.threadpool + + +class ThreadPool: + """ + Threadless thread pool. + """ + + def __init__(self, reactor): + self._reactor = reactor + + def start(self): + pass + + def stop(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + self._reactor.callLater(0, d.callback, True) + return d + def setup_test_homeserver(cleanup_func, *args, **kwargs): """ @@ -289,36 +323,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): **kwargs ) - class ThreadPool: - """ - Threadless thread pool. - """ - - def start(self): - pass - - def stop(self): - pass - - def callInThreadWithCallback(self, onResult, function, *args, **kwargs): - def _(res): - if isinstance(res, Failure): - onResult(False, res) - else: - onResult(True, res) - - d = Deferred() - d.addCallback(lambda x: function(*args, **kwargs)) - d.addBoth(_) - clock._reactor.callLater(0, d.callback, True) - return d - - clock.threadpool = ThreadPool() - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - pool.threadpool = ThreadPool() + pool.threadpool = ThreadPool(clock._reactor) pool.running = True return d @@ -454,6 +462,6 @@ class FakeTransport(object): logger.warning("Exception writing to protocol: %s", e) return - self.buffer = self.buffer[len(to_write):] + self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 95badc985..872039c8f 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.client.v1 import admin, login, room +import os + +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -23,27 +26,34 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): servlets = [ sync.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, room.register_servlets, ] def make_homeserver(self, reactor, clock): + tmpdir = self.mktemp() + os.mkdir(tmpdir) self.consent_notice_message = "consent %(consent_uri)s" config = self.default_config() - config.user_consent_version = "1" - config.user_consent_server_notice_content = { - "msgtype": "m.text", - "body": self.consent_notice_message, + config["user_consent"] = { + "version": "1", + "template_dir": tmpdir, + "server_notice_content": { + "msgtype": "m.text", + "body": self.consent_notice_message, + }, } - config.public_baseurl = "https://example.com/" - config.form_secret = "123abc" + config["public_baseurl"] = "https://example.com/" + config["form_secret"] = "123abc" - config.server_notices_mxid = "@notices:test" - config.server_notices_mxid_display_name = "test display name" - config.server_notices_mxid_avatar_url = None - config.server_notices_room_name = "Server Notices" + config["server_notices"] = { + "system_mxid_localpart": "notices", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + } hs = self.setup_test_homeserver(config=config) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index be73e718c..739ee59ce 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,10 +27,14 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): hs_config = self.default_config("test") - hs_config.server_notices_mxid = "@server:test" + hs_config["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + } hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) return hs @@ -80,7 +84,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False - self.hs.limit_usage_by_mau = False + self.hs.config.limit_usage_by_mau = False self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f448b0132..9c5311d91 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -50,6 +50,7 @@ class FakeEvent(object): refer to events. The event_id has node_id as localpart and example.com as domain. """ + def __init__(self, id, sender, type, state_key, content): self.node_id = id self.event_id = EventID(id, "example.com").to_string() @@ -142,24 +143,14 @@ INITIAL_EVENTS = [ content=MEMBERSHIP_CONTENT_JOIN, ), FakeEvent( - id="START", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), FakeEvent( - id="END", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), ] -INITIAL_EDGES = [ - "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", -] +INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): @@ -170,12 +161,7 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="MA", @@ -196,19 +182,11 @@ class StateTestCase(unittest.TestCase): sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), ] - edges = [ - ["END", "MB", "MA", "PA", "START"], - ["END", "PB", "PA"], - ] + edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]] expected_state_ids = ["PA", "MA", "MB"] @@ -232,10 +210,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "JR", "START"], - ["END", "ME", "START"], - ] + edges = [["END", "JR", "START"], ["END", "ME", "START"]] expected_state_ids = ["JR"] @@ -248,45 +223,25 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] - edges = [ - ["END", "PC", "PB", "PA", "START"], - ["END", "PA"], - ] + edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]] expected_state_ids = ["PC"] @@ -295,68 +250,38 @@ class StateTestCase(unittest.TestCase): def test_topic_basic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), ] - edges = [ - ["END", "PA2", "T2", "PA1", "T1", "START"], - ["END", "T3", "PB", "PA1"], - ] + edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]] expected_state_ids = ["PA2", "T2"] @@ -365,30 +290,17 @@ class StateTestCase(unittest.TestCase): def test_topic_reset(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MB", @@ -399,10 +311,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "MB", "T2", "PA", "T1", "START"], - ["END", "T1"], - ] + edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]] expected_state_ids = ["T1", "MB", "PA"] @@ -411,61 +320,34 @@ class StateTestCase(unittest.TestCase): def test_topic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MZ1", @@ -475,11 +357,7 @@ class StateTestCase(unittest.TestCase): content={}, ), FakeEvent( - id="T4", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), ] @@ -587,13 +465,7 @@ class StateTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase): def test_simple(self): - graph = { - "l": {"o"}, - "m": {"n", "o"}, - "n": {"o"}, - "o": set(), - "p": {"o"}, - } + graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -680,7 +552,13 @@ class SimpleParamStateTestCase(unittest.TestCase): self.expected_combined_state = { (e.type, e.state_key): e.event_id - for e in [create_event, alice_member, join_rules, bob_member, charlie_member] + for e in [ + create_event, + alice_member, + join_rules, + bob_member, + charlie_member, + ] } def test_event_map_none(self): @@ -720,11 +598,7 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return { - eid: self.event_map[eid] - for eid in event_ids - if eid in self.event_map - } + return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} def get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3f0083831..25a6c89ef 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store._get_events = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5568a607c..fbb930269 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup - ) + hs = yield setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f18db8c38..c778de1f0 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False hs = TestHomeServer( - "test", - db_pool=self.db_pool, - config=config, - database_engine=fake_engine, + "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) self.datastore = SQLBaseStore(None, hs) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py new file mode 100644 index 000000000..6dda66ecd --- /dev/null +++ b/tests/storage/test_cleanup_extrems.py @@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path + +from synapse.api.constants import EventTypes +from synapse.storage import prepare_database +from synapse.types import Requester, UserID + +from tests.unittest import HomeserverTestCase + + +class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): + """Test the background update to clean forward extremities table. + """ + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.event_creator = homeserver.get_event_creation_handler() + self.room_creator = homeserver.get_room_creation_handler() + + # Create a test user and room + self.user = UserID("alice", "test") + self.requester = Requester(self.user, None, False, None, None) + info = self.get_success(self.room_creator.create_room(self.requester, {})) + self.room_id = info["room_id"] + + def create_and_send_event(self, soft_failed=False, prev_event_ids=None): + """Create and send an event. + + Args: + soft_failed (bool): Whether to create a soft failed event or not + prev_event_ids (list[str]|None): Explicitly set the prev events, + or if None just use the default + + Returns: + str: The new event's ID. + """ + prev_events_and_hashes = None + if prev_event_ids: + prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] + + event, context = self.get_success( + self.event_creator.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.user.to_string(), + "content": {"body": "", "msgtype": "m.text"}, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + ) + + if soft_failed: + event.internal_metadata.soft_failed = True + + self.get_success( + self.event_creator.send_nonmember_event(self.requester, event, context) + ) + + return event.event_id + + def add_extremity(self, event_id): + """Add the given event as an extremity to the room. + """ + self.get_success( + self.store._simple_insert( + table="event_forward_extremities", + values={"room_id": self.room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.store.get_latest_event_ids_in_room.invalidate((self.room_id,)) + + def run_background_update(self): + """Re run the background update to clean up the extremities. + """ + # Make sure we don't clash with in progress updates. + self.assertTrue(self.store._all_done, "Background updates are still ongoing") + + schema_path = os.path.join( + prepare_database.dir_path, + "schema", + "delta", + "54", + "delete_forward_extremities.sql", + ) + + def run_delta_file(txn): + prepare_database.executescript(txn, schema_path) + + self.get_success( + self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + ) + + # Ugh, have to reset this flag + self.store._all_done = False + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + def test_soft_failed_extremities_handled_correctly(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- SF2 <- B + + Where SF* are soft failed. + """ + + # Create the room graph + event_id_1 = self.create_and_send_event() + event_id_2 = self.create_and_send_event(True, [event_id_1]) + event_id_3 = self.create_and_send_event(True, [event_id_2]) + event_id_4 = self.create_and_send_event(False, [event_id_3]) + + # Check the latest events are as expected + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + + self.assertEqual(latest_event_ids, [event_id_4]) + + def test_basic_cleanup(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- B + + Where SF* are soft failed, and with extremities of A and B + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_b = self.create_and_send_event(False, [event_id_sf1]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(latest_event_ids, [event_id_b]) + + def test_chain_of_fail_cleanup(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- SF2 <- B + + Where SF* are soft failed, and with extremities of A and B + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_sf2 = self.create_and_send_event(True, [event_id_sf1]) + event_id_b = self.create_and_send_event(False, [event_id_sf2]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(latest_event_ids, [event_id_b]) + + def test_forked_graph_cleanup(self): + r"""Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like, where time flows down the page: + + A B + / \ / + / \ / + SF1 SF2 + | | + SF3 | + / \ | + | \ | + C SF4 + + Where SF* are soft failed, and with them A, B and C marked as + extremities. This should resolve to B and C being marked as extremity. + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_b = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b]) + event_id_sf3 = self.create_and_send_event(True, [event_id_sf1]) + self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4 + event_id_c = self.create_and_send_event(False, [event_id_sf3]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual( + set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) + ) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 858efe499..b62eae7ab 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -18,8 +18,9 @@ from mock import Mock from twisted.internet import defer +import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import admin, login +from synapse.rest.client.v1 import login from tests import unittest @@ -205,7 +206,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets, login.register_servlets] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver() diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 11fb8c0c1..cd2bcd4ca 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -20,7 +20,6 @@ import tests.utils class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 6bfaa00fe..e07ff0120 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ import signedjson.key from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): store = self.hs.get_datastore() - d = store.store_server_verify_key("server1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("server1", "from_server", 0, KEY_2) + key_id_1 = "ed25519:key1" + key_id_2 = "ed25519:KEY_ID_2" + d = store.store_server_verify_keys( + "from_server", + 10, + [ + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys( - [ - ("server1", "ed25519:key1"), - ("server1", "ed25519:key2"), - ("server1", "ed25519:key3"), - ] + [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] ) res = self.get_success(d) self.assertEqual(len(res.keys()), 3) - self.assertEqual(res[("server1", "ed25519:key1")].version, "key1") - self.assertEqual(res[("server1", "ed25519:key2")].version, "key2") + res1 = res[("server1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("server1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + # version comes from the ID it was stored with + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2) + d = store.store_server_verify_keys( + "from_server", + 0, + [ + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) - d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2) + d = store.store_server_verify_keys( + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6569a82b..f458c0305 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.register(user_id=user1, token="123", password_hash=None) self.store.register(user_id=user2, token="456", password_hash=None) self.store.register( - user_id=user3, token="789", - password_hash=None, user_type=UserTypes.SUPPORT + user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) @@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed( - self.hs.get_clock().time_msec() - ) + return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.store.populate_monthly_active_users('user_id') self.pump() @@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user_id=support_user_id, token="123", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) self.store.upsert_monthly_active_user(support_user_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0fc5019e9..4823d44de 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"body": body, "msgtype": u"message"}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase): "room_id": room.to_string(), "content": {"reason": reason}, "redacts": event_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index cb3cc4d2e..c0e0155bb 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user_id=SUPPORT_USER, token="456", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 063387863..73ed943f5 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"membership": membership}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 78e260a7f..b6169436d 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) @@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase): "state_key": state_key, "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups_ids( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - { - (EventTypes.Create, ''): e1.event_id, - (EventTypes.Name, ''): e2.event_id, - }, + {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, ) @defer.inlineCallbacks @@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups( - self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] - self.assertEqual( - {ev.event_id for ev in state_list}, - {e1.event_id, e2.event_id}, - ) + self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) @defer.inlineCallbacks def test_get_state_for_event(self): @@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event( - e5.event_id, - ) + state = yield self.store.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, include_others=True, - ) + ), ) self.assertStateMapEqual( @@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase): # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True ), ) @@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index fd3361404..d7d244ce9 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): yield self.store.update_profile_in_user_dir(ALICE, "alice", None) yield self.store.update_profile_in_user_dir(BOB, "bob", None) yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms( - "!room:id", (ALICE, BOB) - ) + yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 4c8f87e95..8b2741d27 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,9 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(creator), auth_events, + RoomVersions.V1.identifier, + _random_state_event(creator), + auth_events, do_sig_check=False, ) @@ -82,7 +84,9 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(king), auth_events, + RoomVersions.V1.identifier, + _random_state_event(king), + auth_events, do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 1a5dc32c8..6a8339b56 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,4 +1,3 @@ - from mock import Mock from twisted.internet.defer import maybeDeferred, succeed diff --git a/tests/test_mau.py b/tests/test_mau.py index 00be1a8c2..1fbe0d51f 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), + "red", http_client=None, federation_client=Mock() ) self.store = self.hs.get_datastore() @@ -210,9 +208,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request( - "GET", "/sync", access_token=token - ) + request, channel = self.make_request("GET", "/sync", access_token=token) self.render(request) if channel.code != 200: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 0ff6d0e28..2edbae5c6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,9 +44,7 @@ def get_sample_labels_value(sample): class TestMauLimit(unittest.TestCase): def test_basic(self): gauge = InFlightGauge( - "test1", "", - labels=["test_label"], - sub_metrics=["foo", "bar"], + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] ) def handle1(metrics): @@ -59,37 +57,49 @@ class TestMauLimit(unittest.TestCase): gauge.register(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.register(("key1",), handle1) gauge.register(("key2",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key2",), handle2) gauge.register(("key1",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) def get_metrics_from_gauge(self, gauge): results = {} diff --git a/tests/test_state.py b/tests/test_state.py index 5bcc6aaa1..6491a7105 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -168,7 +168,7 @@ class StateTestCase(unittest.TestCase): "get_state_resolution_handler", ] ) - hs.config = default_config("tesths") + hs.config = default_config("tesths", True) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0968e86a7..52739fbab 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -59,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase): for flow in channel.json_body["flows"]: self.assertIsInstance(flow["stages"], list) self.assertTrue(len(flow["stages"]) > 0) - self.assertEquals(flow["stages"][-1], "m.login.terms") + self.assertTrue("m.login.terms" in flow["stages"]) expected_params = { "m.login.terms": { @@ -69,10 +69,10 @@ class TermsTestCase(unittest.HomeserverTestCase): "name": "My Cool Privacy Policy", "url": "https://example.org/_matrix/consent?v=1.0", }, - "version": "1.0" - }, - }, - }, + "version": "1.0", + } + } + } } self.assertIsInstance(channel.json_body["params"], dict) self.assertDictContainsSubset(channel.json_body["params"], expected_params) diff --git a/tests/test_types.py b/tests/test_types.py index d314a7ff5..d83c36559 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -94,8 +94,7 @@ class MapUsernameTestCase(unittest.TestCase): def testSymbols(self): self.assertEqual( - map_username_to_mxid_localpart("test=$?_1234"), - "test=3d=24=3f_1234", + map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) def testLeadingUnderscore(self): @@ -105,6 +104,5 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), - "t=c3=aast", + map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index d0bc8e211..fde0baee8 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -22,6 +22,7 @@ from synapse.util.logcontext import LoggingContextFilter class ToTwistedHandler(logging.Handler): """logging handler which sends the logs to the twisted log""" + tx_log = twisted.logger.Logger() def emit(self, record): @@ -41,7 +42,8 @@ def setup_logging(): root_logger = logging.getLogger() log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" + "%(asctime)s - %(name)s - %(lineno)d - " + "%(levelname)s - %(request)s - %(message)s" ) handler = ToTwistedHandler() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 3bdb50051..6a180ddc3 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -132,7 +132,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": "", "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -153,7 +153,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -174,7 +174,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "sender": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/unittest.py b/tests/unittest.py index 8c65736a5..26204470b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,6 +27,7 @@ import twisted.logger from twisted.internet.defer import Deferred from twisted.trial import unittest +from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer @@ -84,9 +85,8 @@ class TestCase(unittest.TestCase): # all future bets are off. if LoggingContext.current_context() is not LoggingContext.sentinel: self.fail( - "Test starting with non-sentinel logging context %s" % ( - LoggingContext.current_context(), - ) + "Test starting with non-sentinel logging context %s" + % (LoggingContext.current_context(),) ) old_level = logging.getLogger().level @@ -181,10 +181,7 @@ class HomeserverTestCase(TestCase): raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) # Register the resources - self.resource = JsonResource(self.hs) - - for servlet in self.servlets: - servlet(self.hs, self.resource) + self.resource = self.create_test_json_resource() from tests.rest.client.v1.utils import RestHelper @@ -230,9 +227,26 @@ class HomeserverTestCase(TestCase): hs = self.setup_test_homeserver() return hs + def create_test_json_resource(self): + """ + Create a test JsonResource, with the relevant servlets registerd to it + + The default implementation calls each function in `servlets` to do the + registration. + + Returns: + JsonResource: + """ + resource = JsonResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + def default_config(self, name="test"): """ - Get a default HomeServer config object. + Get a default HomeServer config dict. Args: name (str): The homeserver name/domain. @@ -286,7 +300,13 @@ class HomeserverTestCase(TestCase): content = json.dumps(content).encode('utf8') return make_request( - self.reactor, method, path, content, access_token, request, shorthand, + self.reactor, + method, + path, + content, + access_token, + request, + shorthand, federation_auth_origin, ) @@ -316,7 +336,14 @@ class HomeserverTestCase(TestCase): kwargs.update(self._hs_args) if "config" not in kwargs: config = self.default_config() - kwargs["config"] = config + else: + config = kwargs["config"] + + # Parse the config from a config dict into a HomeServerConfig + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config) + kwargs["config"] = config_obj + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index 84dd71e47..bf85d3b8e 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -42,10 +42,10 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) self.assertFalse(cancelled[0], "deferred was cancelled prematurely") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_times_out_when_canceller_throws(self): """Test that we have successfully worked around @@ -59,9 +59,9 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_logcontext_is_preserved_on_cancellation(self): blocking_was_cancelled = [False] @@ -80,10 +80,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), context_one, - "errback %s run in unexpected logcontext %s" % ( - deferred_name, LoggingContext.current_context(), - ) + LoggingContext.current_context(), + context_one, + "errback %s run in unexpected logcontext %s" + % (deferred_name, LoggingContext.current_context()), ) return res @@ -94,11 +94,10 @@ class TimeoutDeferredTest(TestCase): self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) timing_out_d.addErrback(errback, "timingout") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], - "non-completing deferred was not cancelled", + blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/utils.py b/tests/utils.py index cb7551485..200c1ceab 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,7 +68,9 @@ def setupdb(): # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -94,7 +96,9 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -106,7 +110,7 @@ def setupdb(): atexit.register(_cleanup) -def default_config(name): +def default_config(name, parse=False): """ Create a reasonable test config. """ @@ -114,79 +118,74 @@ def default_config(name): "server_name": name, "media_store_path": "media", "uploads_path": "uploads", - # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", + "event_cache_size": 1, + "enable_registration": True, + "enable_registration_captcha": False, + "macaroon_secret_key": "not even a little secret", + "expire_access_token": False, + "trusted_third_party_id_servers": [], + "room_invite_state_types": [], + "password_providers": [], + "worker_replication_url": "", + "worker_app": None, + "email_enable_notifs": False, + "block_non_admin_invites": False, + "federation_domain_whitelist": None, + "filter_timeline_limit": 5000, + "user_directory_search_all_users": False, + "user_consent_server_notice_content": None, + "block_events_without_consent_error": None, + "user_consent_at_registration": False, + "user_consent_policy_name": "Privacy Policy", + "media_storage_providers": [], + "autocreate_auto_join_rooms": True, + "auto_join_rooms": [], + "limit_usage_by_mau": False, + "hs_disabled": False, + "hs_disabled_message": "", + "hs_disabled_limit_type": "", + "max_mau_value": 50, + "mau_trial_days": 0, + "mau_stats_only": False, + "mau_limits_reserved_threepids": [], + "admin_contact": None, + "rc_federation": { + "reject_limit": 10, + "sleep_limit": 10, + "sleep_delay": 10, + "concurrent": 10, + }, + "rc_message": {"per_second": 10000, "burst_count": 10000}, + "rc_registration": {"per_second": 10000, "burst_count": 10000}, + "rc_login": { + "address": {"per_second": 10000, "burst_count": 10000}, + "account": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 10000, "burst_count": 10000}, + }, + "saml2_enabled": False, + "public_baseurl": None, + "default_identity_server": None, + "key_refresh_interval": 24 * 60 * 60 * 1000, + "old_signing_keys": {}, + "tls_fingerprints": [], + "use_frozen_dicts": False, + # We need a sane default_room_version, otherwise attempts to create + # rooms will fail. + "default_room_version": "1", + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + "update_user_directory": False, } - config = HomeServerConfig() - config.parse_config_dict(config_dict) + if parse: + config = HomeServerConfig() + config.parse_config_dict(config_dict) + return config - # TODO: move this stuff into config_dict or get rid of it - config.event_cache_size = 1 - config.enable_registration = True - config.enable_registration_captcha = False - config.macaroon_secret_key = "not even a little secret" - config.expire_access_token = False - config.trusted_third_party_id_servers = [] - config.room_invite_state_types = [] - config.password_providers = [] - config.worker_replication_url = "" - config.worker_app = None - config.email_enable_notifs = False - config.block_non_admin_invites = False - config.federation_domain_whitelist = None - config.federation_rc_reject_limit = 10 - config.federation_rc_sleep_limit = 10 - config.federation_rc_sleep_delay = 100 - config.federation_rc_concurrent = 10 - config.filter_timeline_limit = 5000 - config.user_directory_search_all_users = False - config.user_consent_server_notice_content = None - config.block_events_without_consent_error = None - config.user_consent_at_registration = False - config.user_consent_policy_name = "Privacy Policy" - config.media_storage_providers = [] - config.autocreate_auto_join_rooms = True - config.auto_join_rooms = [] - config.limit_usage_by_mau = False - config.hs_disabled = False - config.hs_disabled_message = "" - config.hs_disabled_limit_type = "" - config.max_mau_value = 50 - config.mau_trial_days = 0 - config.mau_stats_only = False - config.mau_limits_reserved_threepids = [] - config.admin_contact = None - config.rc_messages_per_second = 10000 - config.rc_message_burst_count = 10000 - config.rc_registration.per_second = 10000 - config.rc_registration.burst_count = 10000 - config.rc_login_address.per_second = 10000 - config.rc_login_address.burst_count = 10000 - config.rc_login_account.per_second = 10000 - config.rc_login_account.burst_count = 10000 - config.rc_login_failed_attempts.per_second = 10000 - config.rc_login_failed_attempts.burst_count = 10000 - config.saml2_enabled = False - config.public_baseurl = None - config.default_identity_server = None - config.key_refresh_interval = 24 * 60 * 60 * 1000 - config.old_signing_keys = {} - config.tls_fingerprints = [] - - config.use_frozen_dicts = False - - # we need a sane default_room_version, otherwise attempts to create rooms will - # fail. - config.default_room_version = "1" - - # disable user directory updates, because they get done in the - # background, which upsets the test runner. - config.update_user_directory = False - - return config + return config_dict class TestHomeServer(HomeServer): @@ -220,7 +219,7 @@ def setup_test_homeserver( from twisted.internet import reactor if config is None: - config = default_config(name) + config = default_config(name, parse=True) config.ldap_enabled = False @@ -377,12 +376,7 @@ def register_federation_servlets(hs, resource): resource=resource, authenticator=federation_server.Authenticator(hs), ratelimiter=FederationRateLimiter( - hs.get_clock(), - window_size=hs.config.federation_rc_window_size, - sleep_limit=hs.config.federation_rc_sleep_limit, - sleep_msec=hs.config.federation_rc_sleep_delay, - reject_limit=hs.config.federation_rc_reject_limit, - concurrent_requests=hs.config.federation_rc_concurrent, + hs.get_clock(), config=hs.config.rc_federation ), ) diff --git a/tox.ini b/tox.ini index ef543890f..543b232ae 100644 --- a/tox.ini +++ b/tox.ini @@ -24,6 +24,11 @@ deps = pip>=10 setenv = + # we have a pyproject.toml, but don't want pip to use it for building. + # (otherwise we get an error about 'editable mode is not supported for + # pyproject.toml-style projects'). + PIP_USE_PEP517 = false + PYTHONDONTWRITEBYTECODE = no_byte_code COVERAGE_PROCESS_START = {toxinidir}/.coveragerc @@ -89,7 +94,7 @@ commands = # Make all greater-thans equals so we test the oldest version of our direct # dependencies, but make the pyopenssl 17.0, which can work against an # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis). - /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs pip install' + /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' # Add this so that coverage will run on subprocesses /bin/sh -c 'echo "import coverage; coverage.process_startup()" > {envsitepackagesdir}/../sitecustomize.py'