Merge remote-tracking branch 'upstream/release-v1.25.0'

Tulir Asokan 2021-01-06 14:44:59 +02:00
commit f461e13192
247 changed files with 7532 additions and 3444 deletions


@@ -5,9 +5,10 @@ jobs:
       - image: docker:git
     steps:
       - checkout
-      - setup_remote_docker
       - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
+      # for release builds, we want to get the amd64 image out asap, so first
+      # we do an amd64-only build, before following up with a multiarch build.
       - docker_build:
           tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
           platforms: linux/amd64
@@ -20,12 +21,10 @@ jobs:
       - image: docker:git
     steps:
      - checkout
-      - setup_remote_docker
       - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
-      - docker_build:
-          tag: -t matrixdotorg/synapse:latest
-          platforms: linux/amd64
+      # for `latest`, we don't want the arm images to disappear, so don't update the tag
+      # until all of the platforms are built.
       - docker_build:
           tag: -t matrixdotorg/synapse:latest
           platforms: linux/amd64,linux/arm/v7,linux/arm64
@@ -46,12 +45,16 @@ workflows:
 commands:
   docker_prepare:
-    description: Downloads the buildx cli plugin and enables multiarch images
+    description: Sets up a remote docker server, downloads the buildx cli plugin, and enables multiarch images
     parameters:
       buildx_version:
         type: string
         default: "v0.4.1"
     steps:
+      - setup_remote_docker:
+          # 19.03.13 was the most recent available on circleci at the time of
+          # writing.
+          version: 19.03.13
       - run: apk add --no-cache curl
       - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache
       - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx
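
For orientation: once the buildx plugin is installed and the remote docker is set up, a multiarch build-and-push is a single CLI invocation. A hedged sketch (the tag and platform list are taken from the config above; the repository's actual `docker_build` command may pass different flags):

```sh
# Illustrative only: roughly the invocation a multiarch docker_build step
# performs once ~/.docker/cli-plugins/docker-buildx is in place.
docker buildx build --push \
    --platform linux/amd64,linux/arm/v7,linux/arm64 \
    -t matrixdotorg/synapse:latest .
```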


@@ -1,3 +1,104 @@
+Synapse 1.25.0rc1 (2021-01-06)
+==============================
+
+Removal warning
+---------------
+
+The old [Purge Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/purge_room.md)
+and [Shutdown Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/shutdown_room.md)
+are deprecated and will be removed in a future release. They will be replaced by the
+[Delete Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/rooms.md#delete-room-api).
+
+`POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and
+`POST /_synapse/admin/v1/shutdown_room/<room_id>`.
+
+Features
+--------
+
+- Add an admin API that lets server admins get power in rooms in which local users have power. ([\#8756](https://github.com/matrix-org/synapse/issues/8756))
+- Add optional HTTP authentication to replication endpoints. ([\#8853](https://github.com/matrix-org/synapse/issues/8853))
+- Improve the error messages printed as a result of configuration problems for extension modules. ([\#8874](https://github.com/matrix-org/synapse/issues/8874))
+- Add the number of local devices to Room Details Admin API. Contributed by @dklimpel. ([\#8886](https://github.com/matrix-org/synapse/issues/8886))
+- Add `X-Robots-Tag` header to stop web crawlers from indexing media. Contributed by Aaron Raimist. ([\#8887](https://github.com/matrix-org/synapse/issues/8887))
+- Spam-checkers may now define their methods as `async`. ([\#8890](https://github.com/matrix-org/synapse/issues/8890))
+- Add support for allowing users to pick their own user ID during a single-sign-on login. ([\#8897](https://github.com/matrix-org/synapse/issues/8897), [\#8900](https://github.com/matrix-org/synapse/issues/8900), [\#8911](https://github.com/matrix-org/synapse/issues/8911), [\#8938](https://github.com/matrix-org/synapse/issues/8938), [\#8941](https://github.com/matrix-org/synapse/issues/8941), [\#8942](https://github.com/matrix-org/synapse/issues/8942), [\#8951](https://github.com/matrix-org/synapse/issues/8951))
+- Add an `email.invite_client_location` configuration option to send a web client location to the invite endpoint on the identity server which allows customisation of the email template. ([\#8930](https://github.com/matrix-org/synapse/issues/8930))
+- The search term in the list room and list user Admin APIs is now treated as case-insensitive. ([\#8931](https://github.com/matrix-org/synapse/issues/8931))
+- Apply an IP range blacklist to push and key revocation requests. ([\#8821](https://github.com/matrix-org/synapse/issues/8821), [\#8870](https://github.com/matrix-org/synapse/issues/8870), [\#8954](https://github.com/matrix-org/synapse/issues/8954))
+- Add an option to allow re-use of user-interactive authentication sessions for a period of time. ([\#8970](https://github.com/matrix-org/synapse/issues/8970))
+- Allow running the redact endpoint on workers. ([\#8994](https://github.com/matrix-org/synapse/issues/8994))
+
+Bugfixes
+--------
+
+- Fix bug where we might not correctly calculate the current state for rooms with multiple extremities. ([\#8827](https://github.com/matrix-org/synapse/issues/8827))
+- Fix a long-standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix. ([\#8837](https://github.com/matrix-org/synapse/issues/8837))
+- Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password. ([\#8858](https://github.com/matrix-org/synapse/issues/8858))
+- Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource. ([\#8862](https://github.com/matrix-org/synapse/issues/8862))
+- Add additional validation to pusher URLs to be compliant with the specification. ([\#8865](https://github.com/matrix-org/synapse/issues/8865))
+- Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. ([\#8867](https://github.com/matrix-org/synapse/issues/8867))
+- Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. ([\#8872](https://github.com/matrix-org/synapse/issues/8872))
+- Fix a 500 error when attempting to preview an empty HTML file. ([\#8883](https://github.com/matrix-org/synapse/issues/8883))
+- Fix occasional deadlock when handling SIGHUP. ([\#8918](https://github.com/matrix-org/synapse/issues/8918))
+- Fix login API to not ratelimit application services that have ratelimiting disabled. ([\#8920](https://github.com/matrix-org/synapse/issues/8920))
+- Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config). ([\#8921](https://github.com/matrix-org/synapse/issues/8921))
+- Fix a bug where deactivated users appeared in the user directory when their profile information was updated. ([\#8933](https://github.com/matrix-org/synapse/issues/8933), [\#8964](https://github.com/matrix-org/synapse/issues/8964))
+- Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. ([\#8937](https://github.com/matrix-org/synapse/issues/8937))
+- Fix a bug where 500 errors would be returned if the `m.room_history_visibility` event had invalid content. ([\#8945](https://github.com/matrix-org/synapse/issues/8945))
+- Fix a bug causing common English words to not be considered for a user directory search. ([\#8959](https://github.com/matrix-org/synapse/issues/8959))
+- Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit. ([\#8962](https://github.com/matrix-org/synapse/issues/8962))
+- Fix a long-standing bug where a `m.image` event without a `url` would cause errors on push. ([\#8965](https://github.com/matrix-org/synapse/issues/8965))
+- Fix a small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. ([\#8971](https://github.com/matrix-org/synapse/issues/8971))
+- Add validation to the `sendToDevice` API to raise a missing parameters error instead of a 500 error. ([\#8975](https://github.com/matrix-org/synapse/issues/8975))
+- Add validation of group IDs to raise a 400 error instead of a 500 error. ([\#8977](https://github.com/matrix-org/synapse/issues/8977))
+
+Improved Documentation
+----------------------
+
+- Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules. ([\#8802](https://github.com/matrix-org/synapse/issues/8802))
+- Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839))
+- Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873))
+- Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891))
+- Moved instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987))
+- Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992))
+- Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002))
+
+Deprecations and Removals
+-------------------------
+
+- Deprecate Shutdown Room and Purge Room Admin APIs. ([\#8829](https://github.com/matrix-org/synapse/issues/8829))
+
+Internal Changes
+----------------
+
+- Properly store the mapping of external ID to Matrix ID for CAS users. ([\#8856](https://github.com/matrix-org/synapse/issues/8856), [\#8958](https://github.com/matrix-org/synapse/issues/8958))
+- Remove some unnecessary stubbing from unit tests. ([\#8861](https://github.com/matrix-org/synapse/issues/8861))
+- Remove unused `FakeResponse` class from unit tests. ([\#8864](https://github.com/matrix-org/synapse/issues/8864))
+- Pass `room_id` to `get_auth_chain_difference`. ([\#8879](https://github.com/matrix-org/synapse/issues/8879))
+- Add type hints to push module. ([\#8880](https://github.com/matrix-org/synapse/issues/8880), [\#8882](https://github.com/matrix-org/synapse/issues/8882), [\#8901](https://github.com/matrix-org/synapse/issues/8901), [\#8940](https://github.com/matrix-org/synapse/issues/8940), [\#8943](https://github.com/matrix-org/synapse/issues/8943), [\#9020](https://github.com/matrix-org/synapse/issues/9020))
+- Simplify logic for handling user-interactive-auth via single-sign-on servers. ([\#8881](https://github.com/matrix-org/synapse/issues/8881))
+- Skip the SAML tests if the requirements (`pysaml2` and `xmlsec1`) aren't available. ([\#8905](https://github.com/matrix-org/synapse/issues/8905))
+- Fix multiarch docker image builds. ([\#8906](https://github.com/matrix-org/synapse/issues/8906))
+- Don't publish `latest` docker image until all archs are built. ([\#8909](https://github.com/matrix-org/synapse/issues/8909))
+- Various clean-ups to the structured logging and logging context code. ([\#8916](https://github.com/matrix-org/synapse/issues/8916), [\#8935](https://github.com/matrix-org/synapse/issues/8935))
+- Automatically drop stale forward-extremities under some specific conditions. ([\#8929](https://github.com/matrix-org/synapse/issues/8929))
+- Refactor test utilities for injecting HTTP requests. ([\#8946](https://github.com/matrix-org/synapse/issues/8946))
+- Add a maximum size of 50 kilobytes to .well-known lookups. ([\#8950](https://github.com/matrix-org/synapse/issues/8950))
+- Fix bug in `generate_log_config` script which made it write empty files. ([\#8952](https://github.com/matrix-org/synapse/issues/8952))
+- Clean up tox.ini file; disable coverage checking for non-test runs. ([\#8963](https://github.com/matrix-org/synapse/issues/8963))
+- Add type hints to the admin and room list handlers. ([\#8973](https://github.com/matrix-org/synapse/issues/8973))
+- Add type hints to the receipts and user directory handlers. ([\#8976](https://github.com/matrix-org/synapse/issues/8976))
+- Drop the unused `local_invites` table. ([\#8979](https://github.com/matrix-org/synapse/issues/8979))
+- Add type hints to the base storage code. ([\#8980](https://github.com/matrix-org/synapse/issues/8980))
+- Support using PyJWT v2.0.0 in the test suite. ([\#8986](https://github.com/matrix-org/synapse/issues/8986))
+- Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI. ([\#8998](https://github.com/matrix-org/synapse/issues/8998))
+- Add type hints to the crypto module. ([\#8999](https://github.com/matrix-org/synapse/issues/8999))
+
 
 Synapse 1.24.0 (2020-12-09)
 ===========================
@@ -44,6 +145,58 @@ Internal Changes
 
 - Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898))
 
+Synapse 1.23.1 (2020-12-09)
+===========================
+
+Due to the two security issues highlighted below, server administrators are
+encouraged to update Synapse. We are not aware of these vulnerabilities being
+exploited in the wild.
+
+Security advisory
+-----------------
+
+The following issues are fixed in v1.23.1 and v1.24.0.
+
+- There is a denial of service attack
+  ([CVE-2020-26257](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26257))
+  against the federation APIs in which future events will not be correctly sent
+  to other servers over federation. This affects all servers that participate in
+  open federation. (Fixed in [#8776](https://github.com/matrix-org/synapse/pull/8776)).
+
+- Synapse may be affected by OpenSSL
+  [CVE-2020-1971](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-1971).
+  Synapse administrators should ensure that they have the latest versions of
+  the cryptography Python package installed.
+
+To upgrade Synapse along with the cryptography package:
+
+* Administrators using the [`matrix.org` Docker
+  image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu
+  packages from
+  `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages)
+  should ensure that they have version 1.24.0 or 1.23.1 installed: these images include
+  the updated packages.
+* Administrators who have [installed Synapse from
+  source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source)
+  should upgrade the cryptography package within their virtualenv by running:
+  ```sh
+  <path_to_virtualenv>/bin/pip install 'cryptography>=3.3'
+  ```
+* Administrators who have installed Synapse from distribution packages should
+  consult the information from their distributions.
+
+Bugfixes
+--------
+
+- Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. ([\#8776](https://github.com/matrix-org/synapse/issues/8776))
+
+Internal Changes
+----------------
+
+- Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898))
+
 Synapse 1.24.0rc2 (2020-12-04)
 ==============================


@@ -1,19 +1,44 @@
-- [Choosing your server name](#choosing-your-server-name)
-- [Picking a database engine](#picking-a-database-engine)
-- [Installing Synapse](#installing-synapse)
-  - [Installing from source](#installing-from-source)
-    - [Platform-Specific Instructions](#platform-specific-instructions)
-  - [Prebuilt packages](#prebuilt-packages)
-- [Setting up Synapse](#setting-up-synapse)
-  - [TLS certificates](#tls-certificates)
-  - [Client Well-Known URI](#client-well-known-uri)
-  - [Email](#email)
-  - [Registering a user](#registering-a-user)
-  - [Setting up a TURN server](#setting-up-a-turn-server)
-  - [URL previews](#url-previews)
-- [Troubleshooting Installation](#troubleshooting-installation)
-
-# Choosing your server name
+# Installation Instructions
+
+There are 3 steps to follow under **Installation Instructions**.
+
+- [Installation Instructions](#installation-instructions)
+  - [Choosing your server name](#choosing-your-server-name)
+  - [Installing Synapse](#installing-synapse)
+    - [Installing from source](#installing-from-source)
+      - [Platform-Specific Instructions](#platform-specific-instructions)
+        - [Debian/Ubuntu/Raspbian](#debianubunturaspbian)
+        - [ArchLinux](#archlinux)
+        - [CentOS/Fedora](#centosfedora)
+        - [macOS](#macos)
+        - [OpenSUSE](#opensuse)
+        - [OpenBSD](#openbsd)
+        - [Windows](#windows)
+    - [Prebuilt packages](#prebuilt-packages)
+      - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks)
+      - [Debian/Ubuntu](#debianubuntu)
+        - [Matrix.org packages](#matrixorg-packages)
+        - [Downstream Debian packages](#downstream-debian-packages)
+        - [Downstream Ubuntu packages](#downstream-ubuntu-packages)
+      - [Fedora](#fedora)
+      - [OpenSUSE](#opensuse-1)
+      - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server)
+      - [ArchLinux](#archlinux-1)
+      - [Void Linux](#void-linux)
+      - [FreeBSD](#freebsd)
+      - [OpenBSD](#openbsd-1)
+      - [NixOS](#nixos)
+  - [Setting up Synapse](#setting-up-synapse)
+    - [Using PostgreSQL](#using-postgresql)
+    - [TLS certificates](#tls-certificates)
+    - [Client Well-Known URI](#client-well-known-uri)
+    - [Email](#email)
+    - [Registering a user](#registering-a-user)
+    - [Setting up a TURN server](#setting-up-a-turn-server)
+    - [URL previews](#url-previews)
+    - [Troubleshooting Installation](#troubleshooting-installation)
+
+## Choosing your server name
 
 It is important to choose the name for your server before you install Synapse,
 because it cannot be changed later.
@@ -29,28 +54,9 @@ that your email address is probably `user@example.com` rather than
 `user@email.example.com`) - but doing so may require more advanced setup: see
 [Setting up Federation](docs/federate.md).
 
-# Picking a database engine
-
-Synapse offers two database engines:
-
-* [PostgreSQL](https://www.postgresql.org)
-* [SQLite](https://sqlite.org/)
-
-Almost all installations should opt to use PostgreSQL. Advantages include:
-
-* significant performance improvements due to the superior threading and
-  caching model, smarter query optimiser
-* allowing the DB to be run on separate hardware
-
-For information on how to install and use PostgreSQL, please see
-[docs/postgres.md](docs/postgres.md)
-
-By default Synapse uses SQLite and in doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
-
-# Installing Synapse
-
-## Installing from source
+## Installing Synapse
+
+### Installing from source
 
 (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).)
@@ -68,7 +74,7 @@ these on various platforms.
 To install the Synapse homeserver run:
 
-```
+```sh
 mkdir -p ~/synapse
 virtualenv -p python3 ~/synapse/env
 source ~/synapse/env/bin/activate
@@ -85,7 +91,7 @@ prefer.
 This Synapse installation can then be later upgraded by using pip again with the
 update flag:
 
-```
+```sh
 source ~/synapse/env/bin/activate
 pip install -U matrix-synapse
 ```
@@ -93,7 +99,7 @@ pip install -U matrix-synapse
 Before you can start Synapse, you will need to generate a configuration
 file. To do this, run (in your virtualenv, as before):
 
-```
+```sh
 cd ~/synapse
 python -m synapse.app.homeserver \
     --server-name my.domain.name \
@@ -111,45 +117,43 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to
 change your homeserver's keys, you may find that other homeservers have the
 old key cached. If you update the signing key, you should change the name of the
 key in the `<server name>.signing.key` file (the second word) to something
-different. See the
-[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
-for more information on key management).
+different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management).
 
 To actually run your new homeserver, pick a working directory for Synapse to
 run (e.g. `~/synapse`), and:
 
-```
+```sh
 cd ~/synapse
 source env/bin/activate
 synctl start
 ```
 
-### Platform-Specific Instructions
+#### Platform-Specific Instructions
 
-#### Debian/Ubuntu/Raspbian
+##### Debian/Ubuntu/Raspbian
 
 Installing prerequisites on Ubuntu or Debian:
 
-```
+```sh
-sudo apt-get install build-essential python3-dev libffi-dev \
+sudo apt install build-essential python3-dev libffi-dev \
                      python3-pip python3-setuptools sqlite3 \
                      libssl-dev virtualenv libjpeg-dev libxslt1-dev
 ```
 
-#### ArchLinux
+##### ArchLinux
 
 Installing prerequisites on ArchLinux:
 
-```
+```sh
 sudo pacman -S base-devel python python-pip \
                python-setuptools python-virtualenv sqlite3
 ```
 
-#### CentOS/Fedora
+##### CentOS/Fedora
 
 Installing prerequisites on CentOS 8 or Fedora>26:
 
-```
+```sh
 sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  libwebp-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -158,7 +162,7 @@ sudo dnf groupinstall "Development Tools"
 
 Installing prerequisites on CentOS 7 or Fedora<=25:
 
-```
+```sh
 sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -170,11 +174,11 @@ uses SQLite 3.7. You may be able to work around this by installing a more
 recent SQLite version, but it is recommended that you instead use a Postgres
 database: see [docs/postgres.md](docs/postgres.md).
 
-#### macOS
+##### macOS
 
 Installing prerequisites on macOS:
 
-```
+```sh
 xcode-select --install
 sudo easy_install pip
 sudo pip install virtualenv
@@ -184,22 +188,22 @@ brew install pkg-config libffi
 
 On macOS Catalina (10.15) you may need to explicitly install OpenSSL
 via brew and inform `pip` about it so that `psycopg2` builds:
 
-```
+```sh
 brew install openssl@1.1
 export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
 ```
 
-#### OpenSUSE
+##### OpenSUSE
 
 Installing prerequisites on openSUSE:
 
-```
+```sh
 sudo zypper in -t pattern devel_basis
 sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
                python-devel libffi-devel libopenssl-devel libjpeg62-devel
 ```
 
-#### OpenBSD
+##### OpenBSD
 
 A port of Synapse is available under `net/synapse`. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -213,73 +217,72 @@ mounted with `wxallowed` (cf. `mount(8)`).
 
 Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a
 default OpenBSD installation is mounted with `wxallowed`):
 
-```
+```sh
 doas mkdir /usr/local/pobj_wxallowed
 ```
 
 Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are
 configured in `/etc/mk.conf`:
 
-```
+```sh
 doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed
 ```
 
 Setting the `WRKOBJDIR` for building python:
 
-```
+```sh
 echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf
 ```
 
 Building Synapse:
 
-```
+```sh
 cd /usr/ports/net/synapse
 make install
 ```
 
-#### Windows
+##### Windows
 
 If you wish to run or develop Synapse on Windows, the Windows Subsystem For
 Linux provides a Linux environment on Windows 10 which is capable of using the
 Debian, Fedora, or source installation methods. More information about WSL can
-be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
-Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
+be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for
+Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server>
 for Windows Server.
 
-## Prebuilt packages
+### Prebuilt packages
 
 As an alternative to installing from source, prebuilt packages are available
 for a number of platforms.
 
-### Docker images and Ansible playbooks
+#### Docker images and Ansible playbooks
 
 There is an official synapse image available at
-https://hub.docker.com/r/matrixdotorg/synapse which can be used with
+<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
 the docker-compose file available at [contrib/docker](contrib/docker). Further
 information on this including configuration options is available in the README
 on hub.docker.com.
 
 Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
 Dockerfile to automate a synapse server in a single Docker image, at
-https://hub.docker.com/r/avhost/docker-matrix/tags/
+<https://hub.docker.com/r/avhost/docker-matrix/tags/>
 
 Slavi Pantaleev has created an Ansible playbook,
 which installs the official Docker image of Matrix Synapse
 along with many other Matrix-related services (Postgres database, Element, coturn,
 ma1sd, SSL support, etc.).
 For more details, see
-https://github.com/spantaleev/matrix-docker-ansible-deploy
+<https://github.com/spantaleev/matrix-docker-ansible-deploy>
 
-### Debian/Ubuntu
+#### Debian/Ubuntu
 
-#### Matrix.org packages
+##### Matrix.org packages
 
 Matrix.org provides Debian/Ubuntu packages of the latest stable version of
-Synapse via https://packages.matrix.org/debian/. They are available for Debian
+Synapse via <https://packages.matrix.org/debian/>. They are available for Debian
 9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them:
 
-```
+```sh
 sudo apt install -y lsb-release wget apt-transport-https
 sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg
 echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" |
@@ -299,7 +302,7 @@ The fingerprint of the repository signing key (as shown by `gpg
 /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
 `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
 
-#### Downstream Debian packages
+##### Downstream Debian packages
 
 We do not recommend using the packages from the default Debian `buster`
 repository at this time, as they are old and suffer from known security
@@ -311,49 +314,49 @@ for information on how to use backports.
 
 If you are using Debian `sid` or testing, Synapse is available in the default
 repositories and it should be possible to install it simply with:
 
-```
+```sh
 sudo apt install matrix-synapse
 ```
 
-#### Downstream Ubuntu packages
+##### Downstream Ubuntu packages
 
 We do not recommend using the packages in the default Ubuntu repository
 at this time, as they are old and suffer from known security vulnerabilities.
 The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
 
-### Fedora
+#### Fedora
 
 Synapse is in the Fedora repositories as `matrix-synapse`:
 
-```
+```sh
 sudo dnf install matrix-synapse
 ```
 
 Oleg Girko provides Fedora RPMs at
-https://obs.infoserver.lv/project/monitor/matrix-synapse
+<https://obs.infoserver.lv/project/monitor/matrix-synapse>
 
-### OpenSUSE
+#### OpenSUSE
 
 Synapse is in the OpenSUSE repositories as `matrix-synapse`:
 
-```
+```sh
 sudo zypper install matrix-synapse
 ```
 
-### SUSE Linux Enterprise Server
+#### SUSE Linux Enterprise Server
 
 Unofficial packages are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at
-https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/
+<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/>
 
-### ArchLinux
+#### ArchLinux
 
 The quickest way to get up and running with ArchLinux is probably with the community package
-https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of
+<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of
 the necessary dependencies.
 
 pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1):
 
-```
+```sh
 sudo pip install --upgrade pip
 ```
@@ -362,28 +365,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
 compile it under the right architecture. (This should not be needed if
 installing under virtualenv):
 
-```
+```sh
 sudo pip uninstall py-bcrypt
 sudo pip install py-bcrypt
 ```
 
-### Void Linux
+#### Void Linux
 
 Synapse can be found in the void repositories as 'synapse':
 
-```
+```sh
 xbps-install -Su
 xbps-install -S synapse
 ```
 
-### FreeBSD
+#### FreeBSD
 
 Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
 
 - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
 - Packages: `pkg install py37-matrix-synapse`
 
-### OpenBSD
+#### OpenBSD
 
 As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -392,20 +395,35 @@ and mounting it to `/var/synapse` should be taken into consideration.
 
 Installing Synapse:
 
-```
+```sh
 doas pkg_add synapse
 ```
 
-### NixOS
+#### NixOS
 
 Robin Lambertz has packaged Synapse for NixOS at:
-https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
+<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix>
 
-# Setting up Synapse
+## Setting up Synapse
 
 Once you have installed synapse as above, you will need to configure it.
 
-## TLS certificates
+### Using PostgreSQL
+
+By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+very light workloads.
+
+Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include:
+
+- significant performance improvements due to the superior threading and
+  caching model, smarter query optimiser
+- allowing the DB to be run on separate hardware
+
+For information on how to install and use PostgreSQL in Synapse, please see
+[docs/postgres.md](docs/postgres.md)
+
+### TLS certificates
 
 The default configuration exposes a single HTTP port on the local
 interface: `http://localhost:8008`. It is suitable for local testing,
@@ -419,19 +437,19 @@ The recommended way to do so is to set up a reverse proxy on port
 
 Alternatively, you can configure Synapse to expose an HTTPS port. To do
 so, you will need to edit `homeserver.yaml`, as follows:
 
-* First, under the `listeners` section, uncomment the configuration for the
+- First, under the `listeners` section, uncomment the configuration for the
   TLS-enabled listener. (Remove the hash sign (`#`) at the start of
   each line). The relevant lines are like this:
 
-  ```
+  ```yaml
    - port: 8448
     type: http
      tls: true
      resources:
        - names: [client, federation]
  ```
-* You will also need to uncomment the `tls_certificate_path` and
+- You will also need to uncomment the `tls_certificate_path` and
  `tls_private_key_path` lines under the `TLS` section. You will need to manage
   provisioning of these certificates yourself — Synapse had built-in ACME
   support, but the ACMEv1 protocol Synapse implements is deprecated, not
@@ -446,7 +464,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
 For a more detailed guide to configuring your server for federation, see
 [federate.md](docs/federate.md).
 
-## Client Well-Known URI
+### Client Well-Known URI
 
 Setting up the client Well-Known URI is optional but if you set it up, it will
 allow users to enter their full username (e.g. `@user:<server_name>`) into clients
@@ -457,7 +475,7 @@ about the actual homeserver URL you are using.
 The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
 the following format.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -467,7 +485,7 @@ the following format.
 
 It can optionally contain identity server information as well.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -484,7 +502,8 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
 view it.
 
 In nginx this would be something like:
-```
+
+```nginx
 location /.well-known/matrix/client {
     return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
     default_type application/json;
@@ -497,11 +516,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to
 connect to your server. This is the same URL you put for the `m.homeserver`
 `base_url` above.
 
-```
+```yaml
 public_baseurl: "https://<matrix.example.com>"
 ```
 
-## Email
+### Email
 
 It is desirable for Synapse to have the capability to send email. This allows
 Synapse to send password reset emails, send verifications when an email address
@@ -516,7 +535,7 @@ and `notif_from` fields filled out. You may also need to set `smtp_user`,
 
 If email is not configured, password reset, registration and notifications via
 email will be disabled.
 
-## Registering a user
+### Registering a user
 
 The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
@@ -524,7 +543,7 @@ Alternatively you can do so from the command line if you have installed via pip.
 
 This can be done as follows:
 
-```
+```sh
 $ source ~/synapse/env/bin/activate
 $ synctl start # if not already running
 $ register_new_matrix_user -c homeserver.yaml http://localhost:8008
@@ -542,12 +561,12 @@ value is generated by `--generate-config`), but it should be kept secret, as
 anyone with knowledge of it can register users, including admin accounts,
 on your server even if `enable_registration` is `false`.
 
-## Setting up a TURN server
+### Setting up a TURN server
 
 For reliable VoIP calls to be routed via this homeserver, you MUST configure
 a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
 
-## URL previews
+### URL previews
 
 Synapse includes support for previewing URLs, which is disabled by default. To
 turn it on you must enable the `url_preview_enabled: True` config parameter
@@ -557,19 +576,18 @@ This is critical from a security perspective to stop arbitrary Matrix users
 spidering 'internal' URLs on your network. At the very least we recommend that
 your loopback and RFC1918 IP addresses are blacklisted.
 
-This also requires the optional `lxml` and `netaddr` python dependencies to be
-installed. This in turn requires the `libxml2` library to be available - on
-Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
-your OS.
+This also requires the optional `lxml` python dependency to be installed. This
+in turn requires the `libxml2` library to be available - on Debian/Ubuntu this
+means `apt-get install libxml2-dev`, or equivalent for your OS.
 
-# Troubleshooting Installation
+### Troubleshooting Installation
 
 `pip` seems to leak *lots* of memory during installation. For instance, a Linux
 host with 512MB of RAM may run out of memory whilst installing Twisted. If this
 happens, you will have to individually install the dependencies which are
 failing, e.g.:
 
-```
+```sh
 pip install twisted
 ```


@@ -243,6 +243,8 @@ Then update the ``users`` table in the database::
 Synapse Development
 ===================
 
+Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
+
 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see
 `Installing from source <INSTALL.md#installing-from-source>`_.


@@ -75,6 +75,27 @@ for example:
 
     wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
 
+Upgrading to v1.25.0
+====================
+
+Blacklisting IP ranges
+----------------------
+
+Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and
+``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation,
+identity servers, push, and for checking key validity for third-party invite events.
+The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new
+``ip_range_blacklist`` defaults to private IP ranges if it is not defined.
+
+If you have never customised ``federation_ip_range_blacklist`` it is recommended
+that you remove that setting.
+
+If you have customised ``federation_ip_range_blacklist`` you should update the
+setting name to ``ip_range_blacklist``.
+
+If you have a custom push server that is reached via private IP space you may
+need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``.
+
 Upgrading to v1.24.0
 ====================
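
To make the v1.25.0 rename described above concrete, here is a minimal sketch of updating an existing `homeserver.yaml` (it assumes the key sits at the top level of the file; review the result by hand rather than trusting sed with YAML):

```sh
# Rename the deprecated setting in place, keeping a .bak backup copy.
sed -i.bak 's/^federation_ip_range_blacklist:/ip_range_blacklist:/' homeserver.yaml
```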


@@ -58,3 +58,21 @@ groups:
       labels:
         type: "PDU"
       expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
+    labels:
+      type: remote
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
+    labels:
+      type: local
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
+    labels:
+      type: bridges
+  - record: synapse_storage_events_persisted_by_event_type
+    expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
+  - record: synapse_storage_events_persisted_by_origin
+    expr: sum without(type) (synapse_storage_events_persisted_events_sep)
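
Recording rules like those above can be checked for syntax errors before Prometheus loads them. A hedged sketch (the file name is a placeholder for wherever this rules file lives in your deployment):

```sh
# Validate the rules file; promtool exits non-zero on failure.
promtool check rules synapse-v2.rules
```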

debian/changelog

@@ -4,6 +4,12 @@ matrix-synapse-py3 (1.24.0) stable; urgency=medium
 
  -- Synapse Packaging team <packages@matrix.org>  Wed, 09 Dec 2020 10:14:30 +0000
 
+matrix-synapse-py3 (1.23.1) stable; urgency=medium
+
+  * New synapse release 1.23.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 09 Dec 2020 10:40:39 +0000
+
 matrix-synapse-py3 (1.23.0) stable; urgency=medium
 
   * New synapse release 1.23.0.


@@ -69,7 +69,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \
         python3-setuptools \
         python3-venv \
         sqlite3 \
-        libpq-dev
+        libpq-dev \
+        xmlsec1
 
 COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /


@@ -1,3 +1,14 @@
+# Contents
+- [List all media in a room](#list-all-media-in-a-room)
+- [Quarantine media](#quarantine-media)
+  * [Quarantining media by ID](#quarantining-media-by-id)
+  * [Quarantining media in a room](#quarantining-media-in-a-room)
+  * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+- [Delete local media](#delete-local-media)
+  * [Delete a specific local media](#delete-a-specific-local-media)
+  * [Delete local media by date or size](#delete-local-media-by-date-or-size)
+- [Purge Remote Media API](#purge-remote-media-api)
+
 # List all media in a room
 
 This API gets a list of known media in a room.
@@ -11,16 +22,16 @@ To use it, you will need to authenticate by providing an `access_token` for a
 server admin: see [README.rst](README.rst).
 
 The API returns a JSON body like the following:
-```
+```json
 {
   "local": [
     "mxc://localhost/xwvutsrqponmlkjihgfedcba",
     "mxc://localhost/abcdefghijklmnopqrstuvwx"
   ],
   "remote": [
     "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
     "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
   ]
 }
 ```
@@ -48,7 +59,7 @@ form of `abcdefg12345...`.
 
 Response:
 
-```
+```json
 {}
 ```
@@ -68,14 +79,18 @@ Where `room_id` is in the form of `!roomid12345:example.org`.
 
 Response:
 
-```
+```json
 {
-  "num_quarantined": 10 # The number of media items successfully quarantined
+  "num_quarantined": 10
 }
 ```
 
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
 Note that there is a legacy endpoint, `POST
-/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same.
+/_synapse/admin/v1/quarantine_media/<room_id>`, that operates the same.
 However, it is deprecated and may be removed in a future release.
 
 ## Quarantining all media of a user
@@ -92,23 +107,29 @@ POST /_synapse/admin/v1/user/<user_id>/media/quarantine
 {}
 ```
 
-Where `user_id` is in the form of `@bob:example.org`.
+URL Parameters
+
+* `user_id`: string - User ID in the form of `@bob:example.org`
 
 Response:
 
-```
+```json
 {
-  "num_quarantined": 10 # The number of media items successfully quarantined
+  "num_quarantined": 10
 }
 ```
 
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
 # Delete local media
 
 This API deletes the *local* media from the disk of your own server.
 This includes any local thumbnails and copies of media downloaded from
 remote homeservers.
 
 This API will not affect media that has been uploaded to external
 media repositories (e.g. https://github.com/turt2live/matrix-media-repo/).
 
-See also [purge_remote_media.rst](purge_remote_media.rst).
+See also [Purge Remote Media API](#purge-remote-media-api).
@@ -129,12 +150,12 @@ URL Parameters
 
 Response:
 
 ```json
 {
   "deleted_media": [
     "abcdefghijklmnopqrstuvwx"
   ],
   "total": 1
 }
 ```
 
 The following fields are returned in the JSON response body:
@@ -167,16 +188,51 @@ If `false` these files will be deleted. Defaults to `true`.
 
 Response:
 
 ```json
 {
   "deleted_media": [
     "abcdefghijklmnopqrstuvwx",
     "abcdefghijklmnopqrstuvwz"
   ],
   "total": 2
 }
 ```
 
 The following fields are returned in the JSON response body:
 
 * `deleted_media`: an array of strings - List of deleted `media_id`
 * `total`: integer - Total number of deleted `media_id`
+
+# Purge Remote Media API
+
+The purge remote media API allows server admins to purge old cached remote media.
+
+The API is:
+
+```
+POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
+
+{}
+```
+
+URL Parameters
+
+* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in ms.
+  All cached media that was last accessed before this timestamp will be removed.
+
+Response:
+
+```json
+{
+  "deleted": 10
+}
+```
+
+The following fields are returned in the JSON response body:
+
+* `deleted`: integer - The number of media items successfully deleted
+
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
+
+If the user re-requests purged remote media, synapse will re-request the media
+from the originating server.


@@ -1,20 +0,0 @@
-Purge Remote Media API
-======================
-
-The purge remote media API allows server admins to purge old cached remote
-media.
-
-The API is::
-
-    POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
-
-    {}
-
-... which will remove all cached media that was last accessed before
-``<unix_timestamp_in_ms>``.
-
-To use it, you will need to authenticate by providing an ``access_token`` for a
-server admin: see `README.rst <README.rst>`_.
-
-If the user re-requests purged remote media, synapse will re-request the media
-from the originating server.


@@ -1,12 +1,13 @@
-Purge room API
-==============
+Deprecated: Purge room API
+==========================
+
+**The old Purge room API is deprecated and will be removed in a future release.
+See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
 
 This API will remove all trace of a room from your database.
 
 All local users must have left the room before it can be removed.
 
-See also: [Delete Room API](rooms.md#delete-room-api)
-
 The API is:
 
 ```


@@ -1,3 +1,15 @@
+# Contents
+- [List Room API](#list-room-api)
+  * [Parameters](#parameters)
+  * [Usage](#usage)
+- [Room Details API](#room-details-api)
+- [Room Members API](#room-members-api)
+- [Delete Room API](#delete-room-api)
+  * [Parameters](#parameters-1)
+  * [Response](#response)
+  * [Undoing room shutdowns](#undoing-room-shutdowns)
+- [Make Room Admin API](#make-room-admin-api)
+
 # List Room API
 
 The List Room admin API allows server admins to get a list of rooms on their
@@ -76,7 +88,7 @@ GET /_synapse/admin/v1/rooms
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
@@ -128,7 +140,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM
 
 Response:
 
-```
+```json
 {
   "rooms": [
     {
@@ -163,7 +175,7 @@ GET /_synapse/admin/v1/rooms?order_by=size
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
@@ -219,14 +231,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
       "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
       "name": "Music Theory",
       "canonical_alias": "#musictheory:matrix.org",
-      "joined_members": 127
+      "joined_members": 127,
       "joined_local_members": 2,
       "version": "1",
       "creator": "@foo:matrix.org",
@@ -243,7 +255,7 @@ Response:
       "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
       "name": "weechat-matrix",
       "canonical_alias": "#weechat-matrix:termina.org.uk",
-      "joined_members": 137
+      "joined_members": 137,
       "joined_local_members": 20,
       "version": "4",
       "creator": "@foo:termina.org.uk",
@@ -278,6 +290,7 @@ The following fields are possible in the JSON response body:
 * `canonical_alias` - The canonical (main) alias address of the room.
 * `joined_members` - How many users are currently in the room.
 * `joined_local_members` - How many local users are currently in the room.
+* `joined_local_devices` - How many local devices are currently in the room.
 * `version` - The version of the room as a string.
 * `creator` - The `user_id` of the room creator.
 * `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
@@ -300,15 +313,16 @@ GET /_synapse/admin/v1/rooms/<room_id>
 
 Response:
 
-```
+```json
 {
   "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
   "name": "Music Theory",
   "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
   "topic": "Theory, Composition, Notation, Analysis",
   "canonical_alias": "#musictheory:matrix.org",
-  "joined_members": 127
+  "joined_members": 127,
   "joined_local_members": 2,
+  "joined_local_devices": 2,
   "version": "1",
   "creator": "@foo:matrix.org",
   "encryption": null,
@@ -342,13 +356,13 @@ GET /_synapse/admin/v1/rooms/<room_id>/members
 
 Response:
 
-```
+```json
 {
   "members": [
     "@foo:matrix.org",
     "@bar:matrix.org",
-    "@foobar:matrix.org
+    "@foobar:matrix.org"
   ],
   "total": 3
 }
 ```
@@ -357,8 +371,6 @@ Response:
 
 The Delete Room admin API allows server admins to remove rooms from server
 and block these rooms.
-
-It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
-and "[Purge room](purge_room.md)" API.
Shuts down a room. Moves all local users and room aliases automatically to a Shuts down a room. Moves all local users and room aliases automatically to a
new room if `new_room_user_id` is set. Otherwise local users only new room if `new_room_user_id` is set. Otherwise local users only
@ -455,3 +467,47 @@ The following fields are returned in the JSON response body:
* `local_aliases` - An array of strings representing the local aliases that were migrated from * `local_aliases` - An array of strings representing the local aliases that were migrated from
the old room to the new. the old room to the new.
* `new_room_id` - A string representing the room ID of the new room. * `new_room_id` - A string representing the room ID of the new room.
## Undoing room shutdowns
*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
the structure can and does change without notice.
First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
to recover at all:
* If the room was invite-only, your users will need to be re-invited.
* If the room no longer has any members at all, it'll be impossible to rejoin.
* The first user to rejoin will have to do so via an alias on a different server.
With all that being said, if you still want to try and recover the room:
1. For safety reasons, shut down Synapse.
2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
   * As a precaution, it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;` (see the sketch after these steps).
* The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
3. Restart Synapse.
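
As a concrete illustration of step 2, here is a sketch for a PostgreSQL-backed
Synapse using the `psycopg2` driver; the connection details and room ID are
placeholders. `psycopg2` opens a transaction implicitly, so nothing is
committed until the row count has been checked:

```python
# A sketch only: adjust connection parameters and the room ID to your setup.
import psycopg2

conn = psycopg2.connect(dbname="synapse", user="synapse_user", host="localhost")
try:
    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM blocked_rooms WHERE room_id = %s",
            ("!example:example.org",),
        )
        # Verify exactly one row was deleted before committing.
        if cur.rowcount != 1:
            raise RuntimeError("expected 1 deleted row, got %d" % cur.rowcount)
    conn.commit()
except Exception:
    conn.rollback()
    raise
finally:
    conn.close()
```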
You will have to manually handle, if you so choose, the following:
* Aliases that would have been redirected to the Content Violation room.
* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
* Removal of the Content Violation room if desired.
# Make Room Admin API
Grants another user the highest power available to a local user who is in the room.
If the user is not in the room and the room is not publicly joinable, the user is invited first.
By default the server admin (the caller) is granted power, but another user can
optionally be specified, e.g.:
```
POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
{
"user_id": "@foo:example.com"
}
```
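
As a rough illustration, the call could be made from Python with `requests`;
the homeserver URL, access token, and alias below are placeholders. Note that
a room alias must be URL-encoded when placed in the path:

```python
# A sketch only: HOMESERVER, ADMIN_TOKEN and the alias are placeholders.
import urllib.parse

import requests

HOMESERVER = "https://matrix.example.com"
ADMIN_TOKEN = "<admin_access_token>"
room = urllib.parse.quote("#foo:example.com", safe="")

resp = requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{room}/make_room_admin",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={"user_id": "@foo:example.com"},
)
resp.raise_for_status()
```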
View file
@ -1,4 +1,7 @@
# Shutdown room API # Deprecated: Shutdown room API
**The old Shutdown room API is deprecated and will be removed in a future release.
See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
Shuts down a room, preventing new joins and moves local users and room aliases automatically Shuts down a room, preventing new joins and moves local users and room aliases automatically
to a new room. The new room will be created with the user specified by the to a new room. The new room will be created with the user specified by the
@ -10,8 +13,6 @@ disallow any further invites or joins.
The local server will only have the power to move local user and room aliases to The local server will only have the power to move local user and room aliases to
the new room. Users on other servers will be unaffected. the new room. Users on other servers will be unaffected.
See also: [Delete Room API](rooms.md#delete-room-api)
## API ## API
You will need to authenticate with an access token for an admin user. You will need to authenticate with an access token for an admin user.
View file
@ -30,7 +30,12 @@ It returns a JSON body like the following:
], ],
"avatar_url": "<avatar_url>", "avatar_url": "<avatar_url>",
"admin": false, "admin": false,
"deactivated": false "deactivated": false,
"password_hash": "$2b$12$p9B4GkqYdRTPGD",
"creation_ts": 1560432506,
"appservice_id": null,
"consent_server_notice_sent": null,
"consent_version": null
} }
URL parameters: URL parameters:
@ -139,7 +144,6 @@ A JSON body is returned with the following shape:
"users": [ "users": [
{ {
"name": "<user_id1>", "name": "<user_id1>",
"password_hash": "<password_hash1>",
"is_guest": 0, "is_guest": 0,
"admin": 0, "admin": 0,
"user_type": null, "user_type": null,
@ -148,7 +152,6 @@ A JSON body is returned with the following shape:
"avatar_url": null "avatar_url": null
}, { }, {
"name": "<user_id2>", "name": "<user_id2>",
"password_hash": "<password_hash2>",
"is_guest": 0, "is_guest": 0,
"admin": 1, "admin": 1,
"user_type": null, "user_type": null,
View file
@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
You should now have a Django project configured to serve CAS authentication with You should now have a Django project configured to serve CAS authentication with
a single user created. a single user created.
## Configure Synapse (and Riot) to use CAS ## Configure Synapse (and Element) to use CAS
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally 1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
running Django test server: running Django test server:
@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
## Testing the configuration ## Testing the configuration
Then in Riot: Then in Element:
1. Visit the login page with a Riot pointing at your homeserver. 1. Visit the login page with an Element pointing at your homeserver.
2. Click the Single Sign-On button. 2. Click the Single Sign-On button.
3. Login using the credentials created with `createsuperuser`. 3. Login using the credentials created with `createsuperuser`.
4. You should be logged in. 4. You should be logged in.
View file
@ -144,6 +144,47 @@ pid_file: DATADIR/homeserver.pid
# #
#enable_search: false #enable_search: false
# Prevent outgoing requests from being sent to the following blacklisted IP address
# CIDR ranges. If this option is not specified then it defaults to private IP
# address ranges (see the example below).
#
# The blacklist applies to the outbound requests for federation, identity servers,
# push servers, and for checking key validity for third-party invite events.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
# This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
#
#ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '192.0.0.0/24'
# - '169.254.0.0/16'
# - '198.18.0.0/15'
# - '192.0.2.0/24'
# - '198.51.100.0/24'
# - '203.0.113.0/24'
# - '224.0.0.0/4'
# - '::1/128'
# - 'fe80::/10'
# - 'fc00::/7'
# List of IP address CIDR ranges that should be allowed for federation,
# identity servers, push servers, and for checking key validity for
# third-party invite events. This is useful for specifying exceptions to
# wide-ranging blacklisted target IP ranges - e.g. for communication with
# a push server only visible in your network.
#
# This whitelist overrides ip_range_blacklist and defaults to an empty
# list.
#
#ip_range_whitelist:
# - '192.168.1.1'
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
# #
@ -642,27 +683,6 @@ acme:
# - nyc.example.com # - nyc.example.com
# - syd.example.com # - syd.example.com
# Prevent federation requests from being sent to the following
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
# As of Synapse v1.4.0 this option also affects any outbound requests to identity
# servers provided by user input.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
federation_ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
- '192.168.0.0/16'
- '100.64.0.0/10'
- '169.254.0.0/16'
- '::1/128'
- 'fe80::/64'
- 'fc00::/7'
# Report prometheus metrics on the age of PDUs being sent to and received from # Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound # the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems # and outbound federation, though be aware that any delay can be due to problems
@ -953,9 +973,15 @@ media_store_path: "DATADIR/media_store"
# - '172.16.0.0/12' # - '172.16.0.0/12'
# - '192.168.0.0/16' # - '192.168.0.0/16'
# - '100.64.0.0/10' # - '100.64.0.0/10'
# - '192.0.0.0/24'
# - '169.254.0.0/16' # - '169.254.0.0/16'
# - '198.18.0.0/15'
# - '192.0.2.0/24'
# - '198.51.100.0/24'
# - '203.0.113.0/24'
# - '224.0.0.0/4'
# - '::1/128' # - '::1/128'
# - 'fe80::/64' # - 'fe80::/10'
# - 'fc00::/7' # - 'fc00::/7'
# List of IP address CIDR ranges that the URL preview spider is allowed # List of IP address CIDR ranges that the URL preview spider is allowed
@ -1799,9 +1825,10 @@ oidc_config:
# * user: The claims returned by the UserInfo Endpoint and/or in the ID # * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token # Token
# #
# This must be configured if using the default mapping provider. # If this is not set, the user will be prompted to choose their
# own username.
# #
localpart_template: "{{ user.preferred_username }}" #localpart_template: "{{ user.preferred_username }}"
# Jinja2 template for the display name to set on first login. # Jinja2 template for the display name to set on first login.
# #
@ -1877,11 +1904,8 @@ sso:
# - https://my.custom.client/ # - https://my.custom.client/
# Directory in which Synapse will try to find the template files below. # Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used. # If not set, or the files named below are not found within the template
# # directory, default templates from within the Synapse package will be used.
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# #
# Synapse will look for the following templates in this directory: # Synapse will look for the following templates in this directory:
# #
@ -2045,6 +2069,21 @@ password_config:
# #
#require_uppercase: true #require_uppercase: true
ui_auth:
# The number of milliseconds to allow a user-interactive authentication
# session to be active.
#
# This defaults to 0, meaning the user is queried for their credentials
#   before every action, but this can be overridden to allow a single
# validation to be re-used. This weakens the protections afforded by
# the user-interactive authentication process, by allowing for multiple
# (and potentially different) operations to use the same validation session.
#
# Uncomment below to allow for credential validation to last for 15
# seconds.
#
#session_timeout: 15000
# Configuration for sending emails from Synapse. # Configuration for sending emails from Synapse.
# #
@ -2110,10 +2149,15 @@ email:
# #
#validation_token_lifetime: 15m #validation_token_lifetime: 15m
# Directory in which Synapse will try to find the template files below. # The web client location to direct users to during an invite. This is passed
# If not set, default templates from within the Synapse package will be used. # to the identity server as the org.matrix.web_client_location key. Defaults
# to unset, giving no guidance to the identity server.
# #
# Do not uncomment this setting unless you want to customise the templates. #invite_client_location: https://app.element.io
# Directory in which Synapse will try to find the template files below.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
# #
# Synapse will look for the following templates in this directory: # Synapse will look for the following templates in this directory:
# #
@ -2322,7 +2366,7 @@ spam_checker:
# If enabled, non server admins can only create groups with local parts # If enabled, non server admins can only create groups with local parts
# starting with this prefix # starting with this prefix
# #
#group_creation_prefix: "unofficial/" #group_creation_prefix: "unofficial_"
@ -2587,6 +2631,13 @@ opentracing:
# #
#run_background_tasks_on: worker1 #run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
# Configuration for Redis when using workers. This *must* be enabled when # Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration). # using workers (unless using old style direct TCP configuration).
View file
@ -22,6 +22,8 @@ well as some specific methods:
* `user_may_create_room` * `user_may_create_room`
* `user_may_create_room_alias` * `user_may_create_room_alias`
* `user_may_publish_room` * `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
The details of each of these methods (as well as their inputs and outputs) The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class. are documented in the `synapse.events.spamcheck.SpamChecker` class.
@ -32,28 +34,33 @@ call back into the homeserver internals.
### Example ### Example
```python ```python
from synapse.spam_checker_api import RegistrationBehaviour
class ExampleSpamChecker: class ExampleSpamChecker:
def __init__(self, config, api): def __init__(self, config, api):
self.config = config self.config = config
self.api = api self.api = api
def check_event_for_spam(self, foo): async def check_event_for_spam(self, foo):
return False # allow all events return False # allow all events
def user_may_invite(self, inviter_userid, invitee_userid, room_id): async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites return True # allow all invites
def user_may_create_room(self, userid): async def user_may_create_room(self, userid):
return True # allow all room creations return True # allow all room creations
def user_may_create_room_alias(self, userid, room_alias): async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases return True # allow all room aliases
def user_may_publish_room(self, userid, room_id): async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms return True # allow publishing of all rooms
def check_username_for_spam(self, user_profile): async def check_username_for_spam(self, user_profile):
return False # allow all usernames return False # allow all usernames
async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations
``` ```
## Configuration ## Configuration
View file
@ -15,12 +15,18 @@ where SAML mapping providers come into play.
SSO mapping providers are currently supported for OpenID and SAML SSO SSO mapping providers are currently supported for OpenID and SAML SSO
configurations. Please see the details below for how to implement your own. configurations. Please see the details below for how to implement your own.
It is the responsibility of the mapping provider to normalise the SSO attributes It is up to the mapping provider whether the user should be assigned a predefined
and map them to a valid Matrix ID. The Matrix ID based on the SSO attributes, or if the user should be allowed to
[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers) choose their own username.
has some information about what is considered valid. Alternately an easy way to
ensure it is valid is to use a Synapse utility function: In the first case - where users are automatically allocated a Matrix ID - it is
`synapse.types.map_username_to_mxid_localpart`. the responsibility of the mapping provider to normalise the SSO attributes and
map them to a valid Matrix ID. The [specification for Matrix
IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some
information about what is considered valid.
If the mapping provider does not assign a Matrix ID, then Synapse will
automatically serve an HTML page allowing the user to pick their own username.
External mapping providers are provided to Synapse in the form of an external External mapping providers are provided to Synapse in the form of an external
Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere, Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere,
@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods:
with failures=1. The method should then return a different with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`. `localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys: - Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID. - `localpart`: A string, used to generate the Matrix ID. If this is
- displayname: An optional string, the display name for the user. `None`, the user is prompted to pick their own username.
- `displayname`: An optional string, the display name for the user.
* `get_extra_attributes(self, userinfo, token)` * `get_extra_attributes(self, userinfo, token)`
- This method must be async. - This method must be async.
- Arguments: - Arguments:
@ -116,11 +123,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods: A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)` * `__init__(self, parsed_config, module_api)`
- Arguments: - Arguments:
- `parsed_config` - A configuration object that is the return value of the - `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by `parse_config` method. You should set any configuration options needed by
the module here. the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `parse_config(config)` * `parse_config(config)`
- This method should have the `@staticmethod` decoration. - This method should have the `@staticmethod` decoration.
- Arguments: - Arguments:
@ -163,12 +172,13 @@ A custom mapping provider must specify the following methods:
redirected to. redirected to.
- This method must return a dictionary, which will then be used by Synapse - This method must return a dictionary, which will then be used by Synapse
to build a new user. The following keys are allowed: to build a new user. The following keys are allowed:
* `mxid_localpart` - Required. The mxid localpart of the new user. * `mxid_localpart` - The mxid localpart of the new user. If this is
`None`, the user is prompted to pick their own username.
* `displayname` - The displayname of the new user. If not provided, will default to * `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`. the value of `mxid_localpart`.
* `emails` - A list of emails for the new user. If not provided, will * `emails` - A list of emails for the new user. If not provided, will
default to an empty list. default to an empty list.
Alternatively it can raise a `synapse.api.errors.RedirectException` to Alternatively it can raise a `synapse.api.errors.RedirectException` to
redirect the user to another page. This is useful to prompt the user for redirect the user to another page. This is useful to prompt the user for
additional information, e.g. if you want them to provide their own username. additional information, e.g. if you want them to provide their own username.
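
To make the contract concrete, here is a minimal sketch of a SAML mapping
provider implementing these methods. It is illustrative only: the
`saml_response.ava` attribute map and the `map_username_to_mxid_localpart`
helper are borrowed from Synapse's default provider, and the SAML attribute
names (`uid`, `displayName`) are assumptions about what an IdP might send:

```python
# An illustrative skeleton; see the method descriptions above for the
# authoritative contract.
from synapse.types import map_username_to_mxid_localpart


class ExampleSamlMappingProvider:
    def __init__(self, parsed_config, module_api):
        self._config = parsed_config
        self._module_api = module_api

    @staticmethod
    def parse_config(config):
        # This sketch understands no options; return the raw dict unchanged.
        return config

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes) to pull from the response.
        return {"uid"}, {"displayName"}

    def saml_response_to_user_attributes(
        self, saml_response, failures, client_redirect_url
    ):
        # Normalise the assumed "uid" attribute into a valid localpart,
        # appending `failures` to step around mxids that are already taken.
        base = map_username_to_mxid_localpart(saml_response.ava["uid"][0])
        localpart = base + (str(failures) if failures else "")
        displayname = saml_response.ava.get("displayName", [None])[0]
        return {
            "mxid_localpart": localpart,
            "displayname": displayname,
            "emails": [],
        }
```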
View file
@ -89,7 +89,8 @@ shared configuration file.
Normally, only a couple of changes are needed to make an existing configuration Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based listener" for the main process; and secondly, you need to enable redis-based
replication. For example: replication. Optionally, a shared secret can be used to authenticate HTTP
traffic between workers. For example:
```yaml ```yaml
@ -103,6 +104,9 @@ listeners:
resources: resources:
- names: [replication] - names: [replication]
# Add a random shared secret to authenticate traffic.
worker_replication_secret: ""
redis: redis:
enabled: true enabled: true
``` ```
@ -225,6 +229,7 @@ expressions:
^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
# Event sending requests # Event sending requests
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
View file
@ -7,11 +7,17 @@ show_error_codes = True
show_traceback = True show_traceback = True
mypy_path = stubs mypy_path = stubs
warn_unreachable = True warn_unreachable = True
# To find all folders that pass mypy, you can run:
#
# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
files = files =
scripts-dev/sign_json, scripts-dev/sign_json,
synapse/api, synapse/api,
synapse/appservice, synapse/appservice,
synapse/config, synapse/config,
synapse/crypto,
synapse/event_auth.py, synapse/event_auth.py,
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/validator.py, synapse/events/validator.py,
@ -20,6 +26,7 @@ files =
synapse/handlers/_base.py, synapse/handlers/_base.py,
synapse/handlers/account_data.py, synapse/handlers/account_data.py,
synapse/handlers/account_validity.py, synapse/handlers/account_validity.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py, synapse/handlers/appservice.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
@ -38,12 +45,16 @@ files =
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/profile.py, synapse/handlers/profile.py,
synapse/handlers/read_marker.py, synapse/handlers/read_marker.py,
synapse/handlers/receipts.py,
synapse/handlers/register.py, synapse/handlers/register.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_list.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py, synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py, synapse/handlers/saml_handler.py,
synapse/handlers/sso.py,
synapse/handlers/sync.py, synapse/handlers/sync.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth, synapse/handlers/ui_auth,
synapse/http/client.py, synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/matrix_federation_agent.py,
@ -55,23 +66,34 @@ files =
synapse/metrics, synapse/metrics,
synapse/module_api, synapse/module_api,
synapse/notifier.py, synapse/notifier.py,
synapse/push/pusherpool.py, synapse/push,
synapse/push/push_rule_evaluator.py,
synapse/replication, synapse/replication,
synapse/rest, synapse/rest,
synapse/server.py, synapse/server.py,
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/__init__.py,
synapse/storage/_base.py,
synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,
synapse/storage/engines, synapse/storage/engines,
synapse/storage/keys.py,
synapse/storage/persist_events.py, synapse/storage/persist_events.py,
synapse/storage/prepare_database.py,
synapse/storage/purge_events.py,
synapse/storage/push_rule.py,
synapse/storage/relations.py,
synapse/storage/roommember.py,
synapse/storage/state.py, synapse/storage/state.py,
synapse/storage/types.py,
synapse/storage/util, synapse/storage/util,
synapse/streams, synapse/streams,
synapse/types.py, synapse/types.py,
@ -108,6 +130,9 @@ ignore_missing_imports = True
[mypy-h11] [mypy-h11]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-msgpack]
ignore_missing_imports = True
[mypy-opentracing] [mypy-opentracing]
ignore_missing_imports = True ignore_missing_imports = True
View file
@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
) -> Optional[Callable[[MethodSigContext], CallableType]]: ) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith( if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__" "synapse.util.caches.descriptors._CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
): ):
return cached_function_method_signature return cached_function_method_signature
return None return None
View file
@ -40,4 +40,6 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) out = args.output_file
out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
out.flush()
View file
@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.24.0" __version__ = "1.25.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,7 +31,9 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase from synapse.events import EventBase
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
@ -474,7 +476,7 @@ class Auth:
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
@ -646,7 +648,8 @@ class Auth:
) )
if ( if (
visibility visibility
and visibility.content["history_visibility"] == "world_readable" and visibility.content.get("history_visibility")
== HistoryVisibility.WORLD_READABLE
): ):
return Membership.JOIN, None return Membership.JOIN, None
raise AuthError( raise AuthError(
View file
@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking( async def check_auth_blocking(
self, self,
@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of # We never block the server from doing actions on behalf of
# users. # users.
return return
elif requester.app_service and not self._track_appservice_user_ips:
# If we're authenticated as an appservice then we only block
# auth if `track_appservice_user_ips` is set, as that option
# implicitly means that application services are part of MAU
# limits.
return
# Never fail an auth check for the server notices users or support user # Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking # This can be a problem where event creation is prohibited due to blocking
View file
@ -95,6 +95,8 @@ class EventTypes:
Presence = "m.presence" Presence = "m.presence"
Dummy = "org.matrix.dummy_event"
class RejectedReason: class RejectedReason:
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"
@ -160,3 +162,10 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes: class AccountDataTypes:
DIRECT = "m.direct" DIRECT = "m.direct"
IGNORED_USER_LIST = "m.ignored_user_list" IGNORED_USER_LIST = "m.ignored_user_list"
class HistoryVisibility:
INVITED = "invited"
JOINED = "joined"
SHARED = "shared"
WORLD_READABLE = "world_readable"
View file
@ -245,6 +245,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery. # Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"): if hasattr(signal, "SIGHUP"):
reactor = hs.get_reactor()
@wrap_as_background_process("sighup") @wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs): def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if # Tell systemd our state, if we're using it. This will silently fail if
@ -260,7 +262,9 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# is so that we're in a sane state, e.g. flushing the logs may fail # is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry. # if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs): def run_sighup(*args, **kwargs):
hs.get_clock().call_later(0, handle_sighup, *args, **kwargs) # `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup) signal.signal(signal.SIGHUP, run_sighup)
View file
@ -89,7 +89,7 @@ from synapse.replication.tcp.streams import (
ToDeviceStream, ToDeviceStream,
) )
from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.profile import ( from synapse.rest.client.v1.profile import (
@ -98,20 +98,6 @@ from synapse.rest.client.v1.profile import (
ProfileRestServlet, ProfileRestServlet,
) )
from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.room import (
JoinedRoomMemberListRestServlet,
JoinRoomAliasServlet,
PublicRoomListRestServlet,
RoomEventContextServlet,
RoomInitialSyncRestServlet,
RoomMemberListRestServlet,
RoomMembershipRestServlet,
RoomMessageListRestServlet,
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory from synapse.rest.client.v2_alpha import groups, sync, user_directory
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -266,7 +252,6 @@ class GenericWorkerPresence(BasePresenceHandler):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
@ -513,12 +498,6 @@ class GenericWorkerServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
PublicRoomListRestServlet(self).register(resource)
RoomMemberListRestServlet(self).register(resource)
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource) RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource) LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource)
@ -527,22 +506,19 @@ class GenericWorkerServer(HomeServer):
VoipRestServlet(self).register(resource) VoipRestServlet(self).register(resource)
PushRuleRestServlet(self).register(resource) PushRuleRestServlet(self).register(resource)
VersionsRestServlet(self).register(resource) VersionsRestServlet(self).register(resource)
RoomSendEventRestServlet(self).register(resource)
RoomMembershipRestServlet(self).register(resource)
RoomStateEventRestServlet(self).register(resource)
JoinRoomAliasServlet(self).register(resource)
ProfileAvatarURLRestServlet(self).register(resource) ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource) ProfileRestServlet(self).register(resource)
KeyUploadServlet(self).register(resource) KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource) AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource)
RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource) events.register_servlets(self, resource)
room.register_servlets(self, resource, True)
room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource) InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
user_directory.register_servlets(self, resource) user_directory.register_servlets(self, resource)
View file
@ -19,7 +19,7 @@ import gc
import logging import logging
import os import os
import sys import sys
from typing import Iterable from typing import Iterable, Iterator
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.well_known import WellKnownResource from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
@ -90,7 +91,7 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.tls tls = listener_config.tls
site_tag = listener_config.http_options.tag site_tag = listener_config.http_options.tag
if site_tag is None: if site_tag is None:
site_tag = port site_tag = str(port)
# We always include a health resource. # We always include a health resource.
resources = {"/health": HealthResource()} resources = {"/health": HealthResource()}
@ -107,7 +108,10 @@ class SynapseHomeServer(HomeServer):
logger.debug("Configuring additional resources: %r", additional_resources) logger.debug("Configuring additional resources: %r", additional_resources)
module_api = self.get_module_api() module_api = self.get_module_api()
for path, resmodule in additional_resources.items(): for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule) handler_cls, config = load_module(
resmodule,
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
)
handler = handler_cls(config, module_api) handler = handler_cls(config, module_api)
if IResource.providedBy(handler): if IResource.providedBy(handler):
resource = handler resource = handler
@ -189,6 +193,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource, "/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self), "/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self), "/_synapse/admin": AdminRestResource(self),
"/_synapse/client/pick_username": pick_username_resource(self),
} }
) )
@ -342,7 +347,10 @@ def setup(config_options):
"Synapse Homeserver", config_options "Synapse Homeserver", config_options
) )
except ConfigError as e: except ConfigError as e:
sys.stderr.write("\nERROR: %s\n" % (e,)) sys.stderr.write("\n")
for f in format_config_error(e):
sys.stderr.write(f)
sys.stderr.write("\n")
sys.exit(1) sys.exit(1)
if not config: if not config:
@ -445,6 +453,38 @@ def setup(config_options):
return hs return hs
def format_config_error(e: ConfigError) -> Iterator[str]:
"""
Formats a config error neatly
The idea is to format the immediate error, plus the "causes" of those errors,
hopefully in a way that makes sense to the user. For example:
Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
Failed to parse config for module 'JinjaOidcMappingProvider':
invalid jinja template:
unexpected end of template, expected 'end of print statement'.
Args:
e: the error to be formatted
Returns: An iterator which yields string fragments to be formatted
"""
yield "Error in configuration"
if e.path:
yield " at '%s'" % (".".join(e.path),)
yield ":\n %s" % (e.msg,)
e = e.__cause__
indent = 1
while e:
indent += 1
yield ":\n%s%s" % (" " * indent, str(e))
e = e.__cause__
class SynapseService(service.Service): class SynapseService(service.Service):
""" """
A twisted Service class that will start synapse. Used to run synapse A twisted Service class that will start synapse. Used to run synapse
View file
@ -23,7 +23,7 @@ import urllib.parse
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, List, MutableMapping, Optional from typing import Any, Callable, Iterable, List, MutableMapping, Optional
import attr import attr
import jinja2 import jinja2
@ -32,7 +32,17 @@ import yaml
class ConfigError(Exception): class ConfigError(Exception):
pass """Represents a problem parsing the configuration
Args:
msg: A textual description of the error.
path: Where appropriate, an indication of where in the configuration
the problem lies.
"""
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
self.msg = msg
self.path = path
# We split these messages out to allow packages to override with package # We split these messages out to allow packages to override with package
View file
@ -1,8 +1,9 @@
from typing import Any, List, Optional from typing import Any, Iterable, List, Optional
from synapse.config import ( from synapse.config import (
api, api,
appservice, appservice,
auth,
captcha, captcha,
cas, cas,
consent_config, consent_config,
@ -14,7 +15,6 @@ from synapse.config import (
logger, logger,
metrics, metrics,
oidc_config, oidc_config,
password,
password_auth_providers, password_auth_providers,
push, push,
ratelimiting, ratelimiting,
@ -35,7 +35,10 @@ from synapse.config import (
workers, workers,
) )
class ConfigError(Exception): ... class ConfigError(Exception):
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
self.msg = msg
self.path = path
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str MISSING_REPORT_STATS_SPIEL: str
@ -62,7 +65,7 @@ class RootConfig:
sso: sso.SSOConfig sso: sso.SSOConfig
oidc: oidc_config.OIDCConfig oidc: oidc_config.OIDCConfig
jwt: jwt_config.JWTConfig jwt: jwt_config.JWTConfig
password: password.PasswordConfig auth: auth.AuthConfig
email: emailconfig.EmailConfig email: emailconfig.EmailConfig
worker: workers.WorkerConfig worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig authproviders: password_auth_providers.PasswordAuthProviderConfig
View file
@ -38,14 +38,27 @@ def validate_config(
try: try:
jsonschema.validate(config, json_schema) jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e: except jsonschema.ValidationError as e:
# copy `config_path` before modifying it. raise json_error_to_config_error(e, config_path)
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
raise ConfigError(
"Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) def json_error_to_config_error(
) e: jsonschema.ValidationError, config_path: Iterable[str]
) -> ConfigError:
"""Converts a json validation error to a user-readable ConfigError
Args:
e: the exception to be converted
config_path: the path within the config file. This will be used as a basis
for the error message.
Returns:
a ConfigError
"""
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
return ConfigError(e.message, path)
View file
@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,11 +17,11 @@
from ._base import Config from ._base import Config
class PasswordConfig(Config): class AuthConfig(Config):
"""Password login configuration """Password and login configuration
""" """
section = "password" section = "auth"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
password_config = config.get("password_config", {}) password_config = config.get("password_config", {})
@ -35,6 +36,10 @@ class PasswordConfig(Config):
self.password_policy = password_config.get("policy") or {} self.password_policy = password_config.get("policy") or {}
self.password_policy_enabled = self.password_policy.get("enabled", False) self.password_policy_enabled = self.password_policy.get("enabled", False)
# User-interactive authentication
ui_auth = config.get("ui_auth") or {}
self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
password_config: password_config:
@ -87,4 +92,19 @@ class PasswordConfig(Config):
# Defaults to 'false'. # Defaults to 'false'.
# #
#require_uppercase: true #require_uppercase: true
ui_auth:
# The number of milliseconds to allow a user-interactive authentication
# session to be active.
#
# This defaults to 0, meaning the user is queried for their credentials
        #   before every action, but this can be overridden to allow a single
# validation to be re-used. This weakens the protections afforded by
# the user-interactive authentication process, by allowing for multiple
# (and potentially different) operations to use the same validation session.
#
# Uncomment below to allow for credential validation to last for 15
# seconds.
#
#session_timeout: 15000
""" """
View file
@ -322,6 +322,22 @@ class EmailConfig(Config):
self.email_subjects = EmailSubjectConfig(**subjects) self.email_subjects = EmailSubjectConfig(**subjects)
# The invite client location should be a HTTP(S) URL or None.
self.invite_client_location = email_config.get("invite_client_location") or None
if self.invite_client_location:
if not isinstance(self.invite_client_location, str):
raise ConfigError(
"Config option email.invite_client_location must be type str"
)
if not (
self.invite_client_location.startswith("http://")
or self.invite_client_location.startswith("https://")
):
raise ConfigError(
"Config option email.invite_client_location must be a http or https URL",
path=("email", "invite_client_location"),
)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return ( return (
"""\ """\
@ -389,10 +405,15 @@ class EmailConfig(Config):
# #
#validation_token_lifetime: 15m #validation_token_lifetime: 15m
# Directory in which Synapse will try to find the template files below. # The web client location to direct users to during an invite. This is passed
# If not set, default templates from within the Synapse package will be used. # to the identity server as the org.matrix.web_client_location key. Defaults
# to unset, giving no guidance to the identity server.
# #
# Do not uncomment this setting unless you want to customise the templates. #invite_client_location: https://app.element.io
# Directory in which Synapse will try to find the template files below.
# If not set, or the files named below are not found within the template
# directory, default templates from within the Synapse package will be used.
# #
# Synapse will look for the following templates in this directory: # Synapse will look for the following templates in this directory:
# #
View file
@ -12,12 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from netaddr import IPSet from synapse.config._base import Config
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config from synapse.config._util import validate_config
@ -36,23 +33,6 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist: for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True self.federation_domain_whitelist[domain] = True
self.federation_ip_range_blacklist = config.get(
"federation_ip_range_blacklist", []
)
# Attempt to create an IPSet from the given ranges
try:
self.federation_ip_range_blacklist = IPSet(
self.federation_ip_range_blacklist
)
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
"Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
federation_metrics_domains = config.get("federation_metrics_domains") or [] federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config( validate_config(
_METRICS_FOR_DOMAINS_SCHEMA, _METRICS_FOR_DOMAINS_SCHEMA,
@ -76,27 +56,6 @@ class FederationConfig(Config):
# - nyc.example.com # - nyc.example.com
# - syd.example.com # - syd.example.com
# Prevent federation requests from being sent to the following
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
# As of Synapse v1.4.0 this option also affects any outbound requests to identity
# servers provided by user input.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
federation_ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
- '192.168.0.0/16'
- '100.64.0.0/10'
- '169.254.0.0/16'
- '::1/128'
- 'fe80::/64'
- 'fc00::/7'
# Report prometheus metrics on the age of PDUs being sent to and received from # Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound # the following domains. This can be used to give an idea of "delay" on inbound
# and outbound federation, though be aware that any delay can be due to problems # and outbound federation, though be aware that any delay can be due to problems
View file
@ -32,5 +32,5 @@ class GroupsConfig(Config):
# If enabled, non server admins can only create groups with local parts # If enabled, non server admins can only create groups with local parts
# starting with this prefix # starting with this prefix
# #
#group_creation_prefix: "unofficial/" #group_creation_prefix: "unofficial_"
""" """
View file
@ -18,6 +18,7 @@ from ._base import RootConfig
from .meow import MeowConfig from .meow import MeowConfig
from .api import ApiConfig from .api import ApiConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .auth import AuthConfig
from .cache import CacheConfig from .cache import CacheConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .cas import CasConfig from .cas import CasConfig
@ -31,7 +32,6 @@ from .key import KeyConfig
from .logger import LoggingConfig from .logger import LoggingConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .oidc_config import OIDCConfig from .oidc_config import OIDCConfig
from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig from .push import PushConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
@ -78,7 +78,7 @@ class HomeServerConfig(RootConfig):
CasConfig, CasConfig,
SSOConfig, SSOConfig,
JWTConfig, JWTConfig,
PasswordConfig, AuthConfig,
EmailConfig, EmailConfig,
PasswordAuthProviderConfig, PasswordAuthProviderConfig,
PushConfig, PushConfig,
View file
@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must be taken when using e.g. MemoryHandler to buffer # filter options, but care must be taken when using e.g. MemoryHandler to buffer
# writes. # writes.
log_context_filter = LoggingContextFilter(request="") log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name}) log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory() old_factory = logging.getLogRecordFactory()

View file

@ -66,7 +66,7 @@ class OIDCConfig(Config):
( (
self.oidc_user_mapping_provider_class, self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config, self.oidc_user_mapping_provider_config,
) = load_module(ump_config) ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods # Ensure loaded user mapping module has defined all necessary methods
required_methods = [ required_methods = [
@ -203,9 +203,10 @@ class OIDCConfig(Config):
# * user: The claims returned by the UserInfo Endpoint and/or in the ID # * user: The claims returned by the UserInfo Endpoint and/or in the ID
# Token # Token
# #
# This must be configured if using the default mapping provider. # If this is not set, the user will be prompted to choose their
# own username.
# #
localpart_template: "{{{{ user.preferred_username }}}}" #localpart_template: "{{{{ user.preferred_username }}}}"
# Jinja2 template for the display name to set on first login. # Jinja2 template for the display name to set on first login.
# #
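
The localpart_template above is an ordinary Jinja2 template evaluated against the user's OIDC claims. A minimal standalone sketch of that evaluation (the claim values are made up, and this is not the mapping-provider code itself):

from jinja2 import Template

userinfo = {"preferred_username": "Alice.Example"}  # hypothetical OIDC claims
template = Template("{{ user.preferred_username | lower }}")
localpart = template.render(user=userinfo)
print(localpart)  # -> "alice.example"
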

View file

@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers") or []) providers.extend(config.get("password_providers") or [])
for provider in providers: for i, provider in enumerate(providers):
mod_name = provider["module"] mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided # This is for backwards compat when the ldap auth provider resided
@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
mod_name = LDAP_PROVIDER mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module( (provider_class, provider_config) = load_module(
{"module": mod_name, "config": provider["config"]} {"module": mod_name, "config": provider["config"]},
("password_providers", "<item %i>" % i),
) )
self.password_providers.append((provider_class, provider_config)) self.password_providers.append((provider_class, provider_config))

View file

@ -17,6 +17,9 @@ import os
from collections import namedtuple from collections import namedtuple
from typing import Dict, List from typing import Dict, List
from netaddr import IPSet
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
@ -142,7 +145,7 @@ class ContentRepositoryConfig(Config):
# them to be started. # them to be started.
self.media_storage_providers = [] # type: List[tuple] self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers: for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to # We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend # expose FileStorageProviderBackend
if provider_config["module"] == "file_system": if provider_config["module"] == "file_system":
@ -151,7 +154,9 @@ class ContentRepositoryConfig(Config):
".FileStorageProviderBackend" ".FileStorageProviderBackend"
) )
provider_class, parsed_config = load_module(provider_config) provider_class, parsed_config = load_module(
provider_config, ("media_storage_providers", "<item %i>" % i)
)
wrapper_config = MediaStorageProviderConfig( wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False), provider_config.get("store_local", False),
@ -182,9 +187,6 @@ class ContentRepositoryConfig(Config):
"to work" "to work"
) )
# netaddr is a dependency for url_preview
from netaddr import IPSet
self.url_preview_ip_range_blacklist = IPSet( self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"] config["url_preview_ip_range_blacklist"]
) )
@ -213,6 +215,10 @@ class ContentRepositoryConfig(Config):
# strip final NL # strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
ip_range_blacklist = "\n".join(
" # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
)
return ( return (
r""" r"""
## Media Store ## ## Media Store ##
@ -283,15 +289,7 @@ class ContentRepositoryConfig(Config):
# you uncomment the following list as a starting point. # you uncomment the following list as a starting point.
# #
#url_preview_ip_range_blacklist: #url_preview_ip_range_blacklist:
# - '127.0.0.0/8' %(ip_range_blacklist)s
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
# - '::1/128'
# - 'fe80::/64'
# - 'fc00::/7'
# List of IP address CIDR ranges that the URL preview spider is allowed # List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist. # to access even if they are specified in url_preview_ip_range_blacklist.
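
The `%(ip_range_blacklist)s` placeholder above is filled by ordinary %-formatting: the generator renders each default range as a commented-out YAML list item and splices the result into the sample. In miniature, with an abridged default list:

DEFAULT_IP_RANGE_BLACKLIST = ["127.0.0.0/8", "10.0.0.0/8"]  # abridged

ip_range_blacklist = "\n".join(
    "        # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
)
sample = "#url_preview_ip_range_blacklist:\n%(ip_range_blacklist)s" % {
    "ip_range_blacklist": ip_range_blacklist
}
print(sample)
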

View file

@ -180,7 +180,7 @@ class _RoomDirectoryRule:
self._alias_regex = glob_to_regex(alias) self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id) self._room_id_regex = glob_to_regex(room_id)
except Exception as e: except Exception as e:
raise ConfigError("Failed to parse glob into regex: %s", e) raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id, room_id, aliases): def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases. """Tests if this rule matches the given user_id, room_id and aliases.

View file

@ -125,7 +125,7 @@ class SAML2Config(Config):
( (
self.saml2_user_mapping_provider_class, self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config, self.saml2_user_mapping_provider_config,
) = load_module(ump_dict) ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods # Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module # Note parse_config() is already checked during the call to load_module

View file

@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set
import attr import attr
import yaml import yaml
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.endpoint import parse_and_validate_server_name
@ -39,6 +40,34 @@ logger = logging.Logger(__name__)
# in the list. # in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
DEFAULT_IP_RANGE_BLACKLIST = [
# Localhost
"127.0.0.0/8",
# Private networks.
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
# Carrier grade NAT.
"100.64.0.0/10",
# Address registry.
"192.0.0.0/24",
# Link-local networks.
"169.254.0.0/16",
# Testing networks.
"198.18.0.0/15",
"192.0.2.0/24",
"198.51.100.0/24",
"203.0.113.0/24",
# Multicast.
"224.0.0.0/4",
# Localhost
"::1/128",
# Link-local addresses.
"fe80::/10",
# Unique local addresses.
"fc00::/7",
]
DEFAULT_ROOM_VERSION = "6" DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = ( ROOM_COMPLEXITY_TOO_GREAT = (
@ -256,6 +285,38 @@ class ServerConfig(Config):
# due to resource constraints # due to resource constraints
self.admin_contact = config.get("admin_contact", None) self.admin_contact = config.get("admin_contact", None)
ip_range_blacklist = config.get(
"ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
)
# Attempt to create an IPSet from the given ranges
try:
self.ip_range_blacklist = IPSet(ip_range_blacklist)
except Exception as e:
raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
# Always blacklist 0.0.0.0, ::
self.ip_range_blacklist.update(["0.0.0.0", "::"])
try:
self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
except Exception as e:
raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
# The federation_ip_range_blacklist is used for backwards-compatibility
# and only applies to federation and identity servers. If it is not given,
# default to ip_range_blacklist.
federation_ip_range_blacklist = config.get(
"federation_ip_range_blacklist", ip_range_blacklist
)
try:
self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
except Exception as e:
raise ConfigError(
"Invalid range(s) provided in federation_ip_range_blacklist."
) from e
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/": if self.public_baseurl[-1] != "/":
self.public_baseurl += "/" self.public_baseurl += "/"
@ -561,6 +622,10 @@ class ServerConfig(Config):
def generate_config_section( def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
): ):
ip_range_blacklist = "\n".join(
" # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
)
_, bind_port = parse_and_validate_server_name(server_name) _, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None: if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400
@ -752,6 +817,33 @@ class ServerConfig(Config):
# #
#enable_search: false #enable_search: false
# Prevent outgoing requests from being sent to the following blacklisted IP address
# CIDR ranges. If this option is not specified then it defaults to private IP
# address ranges (see the example below).
#
# The blacklist applies to the outbound requests for federation, identity servers,
# push servers, and for checking key validity for third-party invite events.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
# This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
#
#ip_range_blacklist:
%(ip_range_blacklist)s
# List of IP address CIDR ranges that should be allowed for federation,
# identity servers, push servers, and for checking key validity for
# third-party invite events. This is useful for specifying exceptions to
# wide-ranging blacklisted target IP ranges - e.g. for communication with
# a push server only visible in your network.
#
# This whitelist overrides ip_range_blacklist and defaults to an empty
# list.
#
#ip_range_whitelist:
# - '192.168.1.1'
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
# #
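
How the two options interact at request time can be pictured with a small netaddr sketch; the helper below is illustrative, not Synapse's actual check:

from netaddr import IPAddress, IPSet

blacklist = IPSet(["10.0.0.0/8", "127.0.0.0/8"])
blacklist.update(["0.0.0.0", "::"])  # the always-blacklisted unroutable addresses
whitelist = IPSet(["10.1.2.3"])  # e.g. a push server inside a private range

def is_ip_allowed(ip: str) -> bool:
    addr = IPAddress(ip)
    return addr in whitelist or addr not in blacklist

print(is_ip_allowed("10.1.2.3"))   # True: the whitelist overrides the blacklist
print(is_ip_allowed("10.99.0.1"))  # False: inside a blacklisted range
print(is_ip_allowed("8.8.8.8"))    # True: not blacklisted at all
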

View file

@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
# spam checker, and thus was simply a dictionary with module # spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking # and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary # to see if the option resolves to a dictionary
self.spam_checkers.append(load_module(spam_checkers)) self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
elif isinstance(spam_checkers, list): elif isinstance(spam_checkers, list):
for spam_checker in spam_checkers: for i, spam_checker in enumerate(spam_checkers):
config_path = ("spam_checker", "<item %i>" % i)
if not isinstance(spam_checker, dict): if not isinstance(spam_checker, dict):
raise ConfigError("spam_checker syntax is incorrect") raise ConfigError("expected a mapping", config_path)
self.spam_checkers.append(load_module(spam_checker)) self.spam_checkers.append(load_module(spam_checker, config_path))
else: else:
raise ConfigError("spam_checker syntax is incorrect") raise ConfigError("spam_checker syntax is incorrect")

View file

@ -93,11 +93,8 @@ class SSOConfig(Config):
# - https://my.custom.client/ # - https://my.custom.client/
# Directory in which Synapse will try to find the template files below. # Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used. # If not set, or the files named below are not found within the template
# # directory, default templates from within the Synapse package will be used.
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
# #
# Synapse will look for the following templates in this directory: # Synapse will look for the following templates in this directory:
# #

View file

@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
provider = config.get("third_party_event_rules", None) provider = config.get("third_party_event_rules", None)
if provider is not None: if provider is not None:
self.third_party_event_rules = load_module(provider) self.third_party_event_rules = load_module(
provider, ("third_party_event_rules",)
)
def generate_config_section(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\

View file

@ -85,6 +85,9 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint # The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port") self.worker_replication_http_port = config.get("worker_replication_http_port")
# The shared secret used for authentication when connecting to the main synapse.
self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@ -185,6 +188,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process. # data). If not provided this defaults to the main process.
# #
#run_background_tasks_on: worker1 #run_background_tasks_on: worker1
# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
""" """
def read_arguments(self, args): def read_arguments(self, args):
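
Loosely, the new option works like any shared-secret bearer scheme: the worker sends the secret with each replication request and the receiving process compares it before handling the request. A sketch of that comparison (an assumed helper and header format, not the actual replication code):

from typing import Mapping, Optional


def check_replication_auth(
    headers: Mapping[str, str], worker_replication_secret: Optional[str]
) -> bool:
    if worker_replication_secret is None:
        return True  # the default: traffic is not authenticated
    return headers.get("Authorization") == "Bearer %s" % worker_replication_secret


print(check_replication_auth({}, None))                                      # True
print(check_replication_auth({"Authorization": "Bearer s3cret"}, "s3cret"))  # True
print(check_replication_auth({}, "s3cret"))                                  # False
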

View file

@ -227,7 +227,7 @@ class ConnectionVerifier:
# This code is based on twisted.internet.ssl.ClientTLSOptions. # This code is based on twisted.internet.ssl.ClientTLSOptions.
def __init__(self, hostname: bytes, verify_certs): def __init__(self, hostname: bytes, verify_certs: bool):
self._verify_certs = verify_certs self._verify_certs = verify_certs
_decoded = hostname.decode("ascii") _decoded = hostname.decode("ascii")

View file

@ -18,7 +18,7 @@
import collections.abc import collections.abc
import hashlib import hashlib
import logging import logging
from typing import Dict from typing import Any, Callable, Dict, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict from synapse.events.utils import prune_event, prune_event_dict
from synapse.types import JsonDict from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Hasher = Callable[[bytes], "hashlib._Hash"]
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
def check_event_content_hash(
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> bool:
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug( logger.debug(
@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash return message_hash_bytes == expected_hash
def compute_content_hash(event_dict, hash_algorithm): def compute_content_hash(
event_dict: Dict[str, Any], hash_algorithm: Hasher
) -> Tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the """Compute the content hash of an event, which is the hash of the
unredacted event. unredacted event.
Args: Args:
event_dict (dict): The unredacted event as a dict event_dict: The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event to hash the event
Returns: Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw A tuple of the name of the hash and the hash as raw bytes.
bytes.
""" """
event_dict = dict(event_dict) event_dict = dict(event_dict)
event_dict.pop("age_ts", None) event_dict.pop("age_ts", None)
@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
return hashed.name, hashed.digest() return hashed.name, hashed.digest()
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): def compute_event_reference_hash(
event, hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted """Computes the event reference hash. This is the hash of the redacted
event. event.
Args: Args:
event (FrozenEvent) event
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event to hash the event
Returns: Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw A tuple of the name of the hash and the hash as raw bytes.
bytes.
""" """
tmp_event = prune_event(event) tmp_event = prune_event(event)
event_dict = tmp_event.get_pdu_json() event_dict = tmp_event.get_pdu_json()
@ -156,7 +163,7 @@ def add_hashes_and_signatures(
event_dict: JsonDict, event_dict: JsonDict,
signature_name: str, signature_name: str,
signing_key: SigningKey, signing_key: SigningKey,
): ) -> None:
"""Add content hash and sign the event """Add content hash and sign the event
Args: Args:
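
The Hasher alias above just names the hashlib-constructor convention these functions share. Stripped of Synapse's canonical-JSON encoding and key-stripping, a content hash reduces to the following sketch (json.dumps with sorted keys only approximates canonical JSON):

import hashlib
import json

event_dict = {"type": "m.room.message", "content": {"body": "hi"}}
encoded = json.dumps(event_dict, sort_keys=True, separators=(",", ":")).encode()
hashed = hashlib.sha256(encoded)
print(hashed.name, hashed.hexdigest())  # "sha256" plus the hex digest
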

View file

@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
import urllib import urllib
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr import attr
from signedjson.key import ( from signedjson.key import (
@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.config.key import TrustedKeyServer
from synapse.logging.context import ( from synapse.logging.context import (
PreserveLoggingContext, PreserveLoggingContext,
make_deferred_yieldable, make_deferred_yieldable,
@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background, run_in_background,
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object. A request to verify a JSON object.
Attributes: Attributes:
server_name(str): The name of the server to verify against. server_name: The name of the server to verify against.
key_ids(set[str]): The set of key_ids to that could be used to verify the json_object: The JSON object to verify.
JSON object
json_object(dict): The JSON object to verify. minimum_valid_until_ts: time at which we require the signing key to
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care) be valid. (0 implies we don't care)
request_name: The name of the request.
key_ids: The set of key_ids that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]): key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no a verify key has been fetched. The deferreds' callbacks are run with no
@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError. errbacks with an M_UNAUTHORIZED SynapseError.
""" """
server_name = attr.ib() server_name = attr.ib(type=str)
json_object = attr.ib() json_object = attr.ib(type=JsonDict)
minimum_valid_until_ts = attr.ib() minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib() request_name = attr.ib(type=str)
key_ids = attr.ib(init=False) key_ids = attr.ib(init=False, type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred)) key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name) self.key_ids = signature_ids(self.json_object, self.server_name)
@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring: class Keyring:
def __init__(self, hs, key_fetchers=None): def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
self.clock = hs.get_clock() self.clock = hs.get_clock()
if key_fetchers is None: if key_fetchers is None:
@ -112,22 +122,26 @@ class Keyring:
# completes. # completes.
# #
# These are regular, logcontext-agnostic Deferreds. # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server( def verify_json_for_server(
self, server_name, json_object, validity_time, request_name self,
): server_name: str,
json_object: JsonDict,
validity_time: int,
request_name: str,
) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server """Verify that a JSON object has been signed by a given server
Args: Args:
server_name (str): name of the server which must have signed this object server_name: name of the server which must have signed this object
json_object (dict): object to be checked json_object: object to be checked
validity_time (int): timestamp at which we require the signing key to validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care) be valid. (0 implies we don't care)
request_name (str): an identifier for this json object (eg, an event id) request_name: an identifier for this json object (eg, an event id)
for logging. for logging.
Returns: Returns:
@ -138,12 +152,14 @@ class Keyring:
requests = (req,) requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0]) return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as """Bulk verifies signatures of json objects, bulk fetching keys as
necessary. necessary.
Args: Args:
server_and_json (iterable[Tuple[str, dict, int, str]): server_and_json:
Iterable of (server_name, json_object, validity_time, request_name) Iterable of (server_name, json_object, validity_time, request_name)
tuples. tuples.
@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json for server_name, json_object, validity_time, request_name in server_and_json
) )
def _verify_objects(self, verify_requests): def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server """Does the work of verify_json_[objects_]for_server
Args: Args:
verify_requests (iterable[VerifyJsonRequest]): verify_requests: Iterable of verification requests.
Iterable of verification requests.
Returns: Returns:
List<Deferred[None]>: for each input item, a deferred indicating success List<Deferred[None]>: for each input item, a deferred indicating success
@ -182,7 +199,7 @@ class Keyring:
key_lookups = [] key_lookups = []
handle = preserve_fn(_handle_key_deferred) handle = preserve_fn(_handle_key_deferred)
def process(verify_request): def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list """Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which Adds a key request to key_lookups, and returns a deferred which
@ -222,18 +239,20 @@ class Keyring:
return results return results
async def _start_key_lookups(self, verify_requests): async def _start_key_lookups(
self, verify_requests: List[VerifyJsonRequest]
) -> None:
"""Sets off the key fetches for each verify request """Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved. Once each fetch completes, verify_request.key_ready will be resolved.
Args: Args:
verify_requests (List[VerifyJsonRequest]): verify_requests:
""" """
try: try:
# map from server name to a set of outstanding request ids # map from server name to a set of outstanding request ids
server_to_request_ids = {} server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests: for verify_request in verify_requests:
server_name = verify_request.server_name server_name = verify_request.server_name
@ -275,11 +294,11 @@ class Keyring:
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
async def wait_for_previous_lookups(self, server_names) -> None: async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_names (Iterable[str]): list of servers which we want to look up server_names: list of servers which we want to look up
Returns: Returns:
Resolves once all key lookups for the given servers have Resolves once all key lookups for the given servers have
@ -304,7 +323,7 @@ class Keyring:
loop_count += 1 loop_count += 1
def _get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with For each verify_request, verify_request.key_ready is called back with
@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found. with a SynapseError if none of the keys are found.
Args: Args:
verify_requests (list[VerifyJsonRequest]): list of verify requests verify_requests: list of verify requests
""" """
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations) run_in_background(do_iterations)
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): async def _attempt_key_fetches_with_fetcher(
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
):
"""Use a key fetcher to attempt to satisfy some key requests """Use a key fetcher to attempt to satisfy some key requests
Args: Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys fetcher: fetcher to use to fetch the keys
remaining_requests (set[VerifyJsonRequest]): outstanding key requests. remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list. Any successfully-completed requests will be removed from the list.
""" """
# dict[str, dict[str, int]]: keys to fetch. # The keys to fetch.
# server_name -> key_id -> min_valid_ts # server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict) missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests: for verify_request in remaining_requests:
# any completed requests should already have been removed # any completed requests should already have been removed
@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed) remaining_requests.difference_update(completed)
class KeyFetcher: class KeyFetcher(metaclass=abc.ABCMeta):
async def get_keys(self, keys_to_fetch): @abc.abstractmethod
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Map from server_name -> key_id -> FetchKeyResult
map from server_name -> key_id -> FetchKeyResult
""" """
raise NotImplementedError raise NotImplementedError
@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher): class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store""" """KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( key_ids_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items() for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys() for key_id in keys_for_server.keys()
) )
res = await self.store.get_server_verify_keys(keys_to_fetch) res = await self.store.get_server_verify_keys(key_ids_to_fetch)
keys = {} keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
return keys return keys
class BaseV2KeyFetcher: class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.config = hs.get_config() self.config = hs.get_config()
async def process_v2_response(self, from_server, response_json, time_added_ms): async def process_v2_response(
self, from_server: str, response_json: JsonDict, time_added_ms: int
) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from This is used to parse either the entirety of the response from
@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query. to /_matrix/key/v2/query.
Args: Args:
from_server (str): the name of the server producing this result: either from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query. for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object response_json: the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json time_added_ms: the timestamp to record in server_keys_json
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object Map from key_id to result object
""" """
ts_valid_until_ms = response_json["valid_until_ts"] ts_valid_until_ms = response_json["valid_until_ts"]
@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher): class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers""" """KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers self.key_servers = self.config.key_servers
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
async def get_key(key_server): async def get_key(key_server: TrustedKeyServer) -> Dict:
try: try:
result = await self.get_server_verify_key_v2_indirect( return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server keys_to_fetch, key_server
) )
return result
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e "Key lookup failed from %r: %s", key_server.server_name, e
@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )
union_of_keys = {} union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results: for result in results:
for server_name, keys in result.items(): for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys) union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys return union_of_keys
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): async def get_server_verify_key_v2_indirect(
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts the keys to be fetched. server_name -> key_id -> min_valid_ts
key_server (synapse.config.key.TrustedKeyServer): notary server to query for key_server: notary server to query for the keys
the keys
Returns: Returns:
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map Map from server_name -> key_id -> FetchKeyResult
from server_name -> key_id -> FetchKeyResult
Raises: Raises:
KeyLookupError if there was an error processing the entire response from KeyLookupError if there was an error processing the entire response from
@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e: except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,)) raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys = {} keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
added_keys = [] added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
assert isinstance(query_response, dict)
for response in query_response["server_keys"]: for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter # do this first, so that we can give useful errors thereafter
server_name = response.get("server_name") server_name = response.get("server_name")
@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys return keys
def _validate_perspectives_response(self, key_server, response): def _validate_perspectives_response(
self, key_server: TrustedKeyServer, response: JsonDict
) -> None:
"""Optionally check the signature on the result of a /key/query request """Optionally check the signature on the result of a /key/query request
Args: Args:
key_server (synapse.config.key.TrustedKeyServer): the notary server that key_server: the notary server that produced this result
produced this result
response (dict): the json-decoded Server Keys response object response: the json-decoded Server Keys response object
""" """
perspective_name = key_server.server_name perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys perspective_keys = key_server.verify_keys
@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers""" """KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_federation_http_client()
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, iterable[str]]): keys_to_fetch:
the keys to be fetched. server_name -> key_ids the keys to be fetched. server_name -> key_ids
Returns: Returns:
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: Map from server_name -> key_id -> FetchKeyResult
map from server_name -> key_id -> FetchKeyResult
""" """
results = {} results = {}
async def get_key(key_to_fetch_item): async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item server_name, key_ids = key_to_fetch_item
try: try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items()) await yieldable_gather_results(get_key, keys_to_fetch.items())
return results return results
async def get_server_verify_key_v2_direct(self, server_name, key_ids): async def get_server_verify_key_v2_direct(
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, FetchKeyResult]:
""" """
Args: Args:
server_name (str): server_name:
key_ids (iterable[str]): key_ids:
Returns: Returns:
dict[str, FetchKeyResult]: map from key ID to lookup result Map from key ID to lookup result
Raises: Raises:
KeyLookupError if there was a problem making the lookup KeyLookupError if there was a problem making the lookup
""" """
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another. # we may have found this key as a side-effect of asking for another.
@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e: except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,)) raise KeyLookupError("Remote server returned an error: %s" % (e,))
assert isinstance(response, dict)
if response["server_name"] != server_name: if response["server_name"] != server_name:
raise KeyLookupError( raise KeyLookupError(
"Expected a response for server %r not %r" "Expected a response for server %r not %r"
@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys return keys
async def _handle_key_deferred(verify_request) -> None: async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification """Waits for the key to become available, and then performs a verification
Args: Args:
verify_request (VerifyJsonRequest): verify_request:
Raises: Raises:
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification
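
Underneath all the Keyring plumbing sits a plain signedjson check; with key fetching, caching and deferreds stripped away, verification reduces to this self-contained round trip with a freshly generated key:

from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import sign_json, verify_signed_json

signing_key = generate_signing_key("v1")
signed = sign_json({"foo": "bar"}, "example.org", signing_key)
verify_signed_json(signed, "example.org", get_verify_key(signing_key))
print("signature ok")  # verify_signed_json raises if the signature is bad
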

View file

@ -15,10 +15,11 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.events import synapse.events
@ -39,7 +40,9 @@ class SpamChecker:
else: else:
self.spam_checkers.append(module(config=config)) self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: async def check_event_for_spam(
self, event: "synapse.events.EventBase"
) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server. """Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if If the server considers an event spammy, then it will be rejected if
@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked event: the event to be checked
Returns: Returns:
True if the event is spammy. True or a string if the event is spammy. If a string is returned it
will be used as the error message returned to the user.
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event): if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True return True
return False return False
def user_may_invite( async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool: ) -> bool:
"""Checks if a given user may send an invite """Checks if a given user may send an invite
@ -75,14 +79,18 @@ class SpamChecker:
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if ( if (
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) await maybe_awaitable(
spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
)
is False is False
): ):
return False return False
return True return True
def user_may_create_room(self, userid: str) -> bool: async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room """Checks if a given user may create a room
If this method returns false, the creation request will be rejected. If this method returns false, the creation request will be rejected.
@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False True if the user may create a room, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room(userid) is False: if (
await maybe_awaitable(spam_checker.user_may_create_room(userid))
is False
):
return False return False
return True return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias """Checks if a given user may create a room alias
If this method returns false, the association request will be rejected. If this method returns false, the association request will be rejected.
@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False True if the user may create a room alias, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False: if (
await maybe_awaitable(
spam_checker.user_may_create_room_alias(userid, room_alias)
)
is False
):
return False return False
return True return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool: async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory """Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected. If this method returns false, the publish request will be rejected.
@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False True if the user may publish the room, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False: if (
await maybe_awaitable(
spam_checker.user_may_publish_room(userid, room_id)
)
is False
):
return False return False
return True return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server. """Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in If the server considers a username spammy, then it will not be included in
@ -157,12 +178,12 @@ class SpamChecker:
if checker: if checker:
# Make a copy of the user profile object to ensure the spam checker # Make a copy of the user profile object to ensure the spam checker
# cannot modify it. # cannot modify it.
if checker(user_profile.copy()): if await maybe_awaitable(checker(user_profile.copy())):
return True return True
return False return False
def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid: Optional[dict], email_threepid: Optional[dict],
username: Optional[str], username: Optional[str],
@ -185,7 +206,9 @@ class SpamChecker:
# spam checker # spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None) checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker: if checker:
behaviour = checker(email_threepid, username, request_info) behaviour = await maybe_awaitable(
checker(email_threepid, username, request_info)
)
assert isinstance(behaviour, RegistrationBehaviour) assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW: if behaviour != RegistrationBehaviour.ALLOW:
return behaviour return behaviour
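
The recurring pattern in this file is maybe_awaitable, which lets the SpamChecker call legacy synchronous callbacks and new async ones through a single code path. The idea, reduced to a standalone sketch (not the synapse.util.async_helpers implementation):

import asyncio
import inspect


async def maybe_awaitable(value):
    if inspect.isawaitable(value):
        return await value
    return value


def sync_checker(event):
    return False  # a legacy, synchronous spam checker


async def async_checker(event):
    return "spam detected"  # a new-style async checker


async def main():
    for checker in (sync_checker, async_checker):
        print(await maybe_awaitable(checker({"type": "m.room.message"})))


asyncio.run(main())
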

View file

@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context() ctx = current_context()
@defer.inlineCallbacks
def callback(_, pdu: EventBase): def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
@ -105,7 +106,11 @@ class FederationBase:
) )
return redacted_event return redacted_event
if self.spam_checker.check_event_for_spam(pdu): result = yield defer.ensureDeferred(
self.spam_checker.check_event_for_spam(pdu)
)
if result:
logger.warning( logger.warning(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.event_id,
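
The callback above bridges old-style Twisted code and the now-async spam check: defer.ensureDeferred turns the coroutine into a Deferred that an @defer.inlineCallbacks generator can yield. The same bridge in isolation:

from twisted.internet import defer


async def check_event_for_spam(pdu):
    return False  # stand-in for the real async check


@defer.inlineCallbacks
def callback(pdu):
    result = yield defer.ensureDeferred(check_event_for_spam(pdu))
    print("spammy?", result)


callback({"event_id": "$abc"})  # fires synchronously here: no reactor needed
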

View file

@ -845,7 +845,6 @@ class FederationHandlerRegistry:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.config = hs.config self.config = hs.config
self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()

View file

@ -35,7 +35,7 @@ class TransportLayerClient:
def __init__(self, hs): def __init__(self, hs):
self.server_name = hs.hostname self.server_name = hs.hostname
self.client = hs.get_http_client() self.client = hs.get_federation_http_client()
@log_function @log_function
def get_room_state_ids(self, destination, room_id, event_id): def get_room_state_ids(self, destination, room_id, event_id):

View file

@ -144,7 +144,7 @@ class Authenticator:
): ):
raise FederationDeniedError(origin) raise FederationDeniedError(origin)
if not json_request["signatures"]: if origin is None or not json_request["signatures"]:
raise NoAuthenticationError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED 401, "Missing Authorization headers", Codes.UNAUTHORIZED
) )
@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args: Args:
hs (synapse.server.HomeServer): homeserver hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register. servlet_groups (list[str], optional): List of servlet groups to register.

View file

@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler: class BaseHandler:
""" """
Common base class for the event handlers. Common base class for the event handlers.
Deprecated: new code should not use this. Instead, Handler classes should define the
fields they actually need. The utility methods should either be factored out to
standalone helper functions, or to different Handler classes.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):

View file

@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import List from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.events import FrozenEvent from synapse.events import EventBase
from synapse.types import RoomStreamToken, StateMap from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler): class AdminHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
async def get_whois(self, user): async def get_whois(self, user: UserID) -> JsonDict:
connections = [] connections = []
sessions = await self.store.get_user_ip_and_agents(user) sessions = await self.store.get_user_ip_and_agents(user)
@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret return ret
async def get_user(self, user): async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details""" """Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string()) ret = await self.store.get_user_by_id(user.to_string())
if ret: if ret:
@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids ret["threepids"] = threepids
return ret return ret
async def export_user_data(self, user_id, writer): async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer. """Write all data we have on the user to the given writer.
Args: Args:
user_id (str) user_id: The user ID to fetch data of.
writer (ExfiltrationWriter) writer: The writer to write to.
Returns: Returns:
Resolves when all data for a user has been written. Resolves when all data for a user has been written.
@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0) from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering) to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room # Events that we've processed in this room
written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then # We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track # write out the state at those events. We do this by keeping track
@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events # The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen # that have the unseen event in their prev_events, i.e. the unseen
# events "children". dict[str, set[str]] # events "children".
unseen_to_child_events = {} unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all* # We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most # events that we have and then filtering, this isn't the most
@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
return writer.finished() return writer.finished()
class ExfiltrationWriter: class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data. """Interface used to specify how to write exported data.
""" """
def write_events(self, room_id: str, events: List[FrozenEvent]): @abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room. """Write a batch of events for a room.
""" """
pass raise NotImplementedError()
def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): @abc.abstractmethod
def write_state(
self, room_id: str, event_id: str, state: StateMap[EventBase]
) -> None:
"""Write the state at the given event in the room. """Write the state at the given event in the room.
This only gets called for backward extremities rather than for each This only gets called for backward extremities rather than for each
event. event.
""" """
pass raise NotImplementedError()
def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): @abc.abstractmethod
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[dict]
) -> None:
"""Write an invite for the room, with associated invite state. """Write an invite for the room, with associated invite state.
Args: Args:
room_id room_id: The room ID the invite is for.
event event: The invite event.
state: A subset of the state at the state: A subset of the state at the invite, with a subset of the
invite, with a subset of the event keys (type, state_key event keys (type, state_key, content and sender).
content and sender)
""" """
raise NotImplementedError()
def finished(self): @abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written. """Called when all data has successfully been exported and written.
This function's return value is passed to the caller of This function's return value is passed to the caller of
`export_user_data`. `export_user_data`.
""" """
pass raise NotImplementedError()
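For context, a concrete exporter only needs to implement these four methods. Below is a minimal, hypothetical sketch (the class name `JsonLinesWriter` and the file layout are invented for illustration, not part of Synapse) that appends each batch to per-room JSON-lines files:

```python
import json
import os
from typing import Any, List


class JsonLinesWriter(ExfiltrationWriter):
    """Illustrative writer that dumps a user's data to files on disk."""

    def __init__(self, base_dir: str):
        self.base_dir = base_dir
        os.makedirs(base_dir, exist_ok=True)

    def _path(self, name: str) -> str:
        # Room IDs contain characters like '!' and ':'; a real writer
        # would sanitise them. Kept simple here for illustration.
        return os.path.join(self.base_dir, name)

    def write_events(self, room_id: str, events: List["EventBase"]) -> None:
        # Append each event's JSON form to a file named after the room.
        with open(self._path(f"{room_id}.events.jsonl"), "a") as f:
            for event in events:
                f.write(json.dumps(event.get_dict()) + "\n")

    def write_state(self, room_id: str, event_id: str, state) -> None:
        with open(self._path(f"{room_id}.state.jsonl"), "a") as f:
            record = {event_id: [e.get_dict() for e in state.values()]}
            f.write(json.dumps(record) + "\n")

    def write_invite(self, room_id: str, event, state) -> None:
        with open(self._path(f"{room_id}.invite.json"), "w") as f:
            json.dump(event.get_dict(), f)

    def finished(self) -> Any:
        # This return value is handed back to the caller of export_user_data.
        return self.base_dir
```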

View file

@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import time import time
import unicodedata import unicodedata
@ -22,6 +21,7 @@ import urllib.parse
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
@ -36,6 +36,8 @@ import attr
import bcrypt import bcrypt
import pymacaroons import pymacaroons
from twisted.web.http import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -56,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -193,39 +196,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
self._sso_enabled = ( self._password_localdb_enabled = hs.config.password_localdb_enabled
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
# start out by assuming PASSWORD is enabled; we will remove it later if not. # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = [] login_types = set()
if hs.config.password_localdb_enabled: if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD) login_types.add(LoginType.PASSWORD)
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"): login_types.update(provider.get_supported_login_types().keys())
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
if not self._password_enabled: if not self._password_enabled:
login_types.discard(LoginType.PASSWORD)
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
self._supported_login_types = []
if LoginType.PASSWORD in login_types:
self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD) login_types.remove(LoginType.PASSWORD)
self._supported_login_types.extend(login_types)
self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
ui_auth_types = login_types.copy()
if self._sso_enabled:
ui_auth_types.append(LoginType.SSO)
self._supported_ui_auth_types = ui_auth_types
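The ordering requirement those comments describe can be seen in a toy example (the type values beyond PASSWORD are invented):

```python
# Computed as a set, then flattened to a list with PASSWORD forced first,
# so naive clients that pick the first entry still get password login.
login_types = {"m.login.token", "m.login.password", "com.example.custom"}

supported = []
if "m.login.password" in login_types:
    supported.append("m.login.password")
    login_types.remove("m.login.password")
supported.extend(login_types)

assert supported[0] == "m.login.password"
```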
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
@ -235,6 +226,9 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count, burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
) )
# The number of seconds to keep a UI auth session active.
self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
# Ratelimiter for failed /login attempts # Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter( self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(), clock=hs.get_clock(),
@ -292,7 +286,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any], request_body: Dict[str, Any],
clientip: str, clientip: str,
description: str, description: str,
) -> Tuple[dict, str]: ) -> Tuple[dict, Optional[str]]:
""" """
Checks that the user is who they claim to be, via a UI auth. Checks that the user is who they claim to be, via a UI auth.
@ -319,7 +313,8 @@ class AuthHandler(BaseHandler):
have been given only in a previous call). have been given only in a previous call).
'session_id' is the ID of this session, either passed in by the 'session_id' is the ID of this session, either passed in by the
client or assigned by this call client or assigned by this call. This is None if UI auth was
skipped (by re-using a previous validation).
Raises: Raises:
InteractiveAuthIncompleteError if the client has not yet completed InteractiveAuthIncompleteError if the client has not yet completed
@ -333,13 +328,26 @@ class AuthHandler(BaseHandler):
""" """
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
)
if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
# Return the input parameters, minus the auth key, which matches
# the logic in check_ui_auth.
request_body.pop("auth", None)
return request_body, None
user_id = requester.user.to_string() user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts # Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows # build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types] supported_ui_auth_types = await self._get_available_ui_auth_types(
requester.user
)
flows = [[login_type] for login_type in supported_ui_auth_types]
try: try:
result, params, session_id = await self.check_ui_auth( result, params, session_id = await self.check_ui_auth(
@ -351,7 +359,7 @@ class AuthHandler(BaseHandler):
raise raise
# find the completed login type # find the completed login type
for login_type in self._supported_ui_auth_types: for login_type in supported_ui_auth_types:
if login_type not in result: if login_type not in result:
continue continue
@ -365,8 +373,46 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth") raise AuthError(403, "Invalid auth")
# Note that the access token has been validated.
await self.store.update_access_token_last_validated(requester.access_token_id)
return params, session_id return params, session_id
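For reference, a typical caller is a REST servlet gating a sensitive action behind UI auth. A hedged sketch of the calling pattern (the servlet body and description string are invented; the `validate_user_via_ui_auth` signature is as shown in the hunk above):

```python
# Hypothetical servlet method, simplified for illustration.
async def on_POST(self, request, body):
    requester = await self.auth.get_user_by_req(request)

    # Raises InteractiveAuthIncompleteError (serialised to the client as a
    # 401 carrying the flows) until the user completes a flow. With
    # ui_auth_session_timeout set, a recent validation is re-used and
    # session_id comes back as None.
    params, session_id = await self.auth_handler.validate_user_via_ui_auth(
        requester,
        request,
        body,
        self.hs.get_ip_from_request(request),
        "deactivate your account",
    )
    # ... perform the sensitive action here ...
```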
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
if self._password_localdb_enabled and self._password_enabled:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
if password_hash:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
def get_enabled_auth_types(self): def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types """Return the enabled user-interactive authentication types
@ -423,13 +469,10 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows. all the stages in any of the permitted flows.
""" """
authdict = None
sid = None # type: Optional[str] sid = None # type: Optional[str]
if clientdict and "auth" in clientdict: authdict = clientdict.pop("auth", {})
authdict = clientdict["auth"] if "session" in authdict:
del clientdict["auth"] sid = authdict["session"]
if "session" in authdict:
sid = authdict["session"]
# Convert the URI and method to strings. # Convert the URI and method to strings.
uri = request.uri.decode("utf-8") uri = request.uri.decode("utf-8")
@ -534,6 +577,8 @@ class AuthHandler(BaseHandler):
creds = await self.store.get_completed_ui_auth_stages(session.session_id) creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows: for f in flows:
# If all the required credentials have been supplied, the user has
# successfully completed the UI auth process!
if len(set(f) - set(creds)) == 0: if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can # it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log # include the password in the case of registering, so only log
@ -709,6 +754,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str], device_id: Optional[str],
valid_until_ms: Optional[int], valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None, puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
) -> str: ) -> str:
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@ -725,6 +771,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID) we should always have a device ID)
valid_until_ms: the time until which the token is valid. None for valid_until_ms: the time until which the token is valid. None for
no expiry. no expiry.
is_appservice_ghost: Whether the user is an application service ghost user
Returns: Returns:
The access token for the user's session. The access token for the user's session.
Raises: Raises:
@ -745,7 +792,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
) )
await self.auth.check_auth_blocking(user_id) if (
not is_appservice_ghost
or self.hs.config.appservice.track_appservice_user_ips
):
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
@ -831,7 +882,7 @@ class AuthHandler(BaseHandler):
async def validate_login( async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False, self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't Also used by the user-interactive auth flow to validate auth types which don't
@ -974,7 +1025,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login( async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any], self, username: str, login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login """Helper for validate_login
Handles login, once we've mapped 3pids onto userids Handles login, once we've mapped 3pids onto userids
@ -1029,7 +1080,7 @@ class AuthHandler(BaseHandler):
if result: if result:
return result return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True known_login_type = True
# we've already checked that there is a (valid) password field # we've already checked that there is a (valid) password field
@ -1052,7 +1103,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid( async def check_password_provider_3pid(
self, medium: str, address: str, password: str self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login """Check if a password provider is able to validate a thirdparty login
Args: Args:
@ -1303,15 +1354,14 @@ class AuthHandler(BaseHandler):
) )
async def complete_sso_ui_auth( async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest, self, registered_user_id: str, session_id: str, request: Request,
): ):
"""Having figured out a mxid for this user, complete the HTTP request """Having figured out a mxid for this user, complete the HTTP request
Args: Args:
registered_user_id: The registered user ID to complete SSO login for. registered_user_id: The registered user ID to complete SSO login for.
session_id: The ID of the user-interactive auth session.
request: The request to complete. request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
""" """
# Mark the stage of the authentication as successful. # Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure # Save the user who authenticated with SSO, this will be used to ensure
@ -1327,7 +1377,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login( async def complete_sso_login(
self, self,
registered_user_id: str, registered_user_id: str,
request: SynapseRequest, request: Request,
client_redirect_url: str, client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None, extra_attributes: Optional[JsonDict] = None,
): ):
@ -1355,7 +1405,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login( def _complete_sso_login(
self, self,
registered_user_id: str, registered_user_id: str,
request: SynapseRequest, request: Request,
client_redirect_url: str, client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None, extra_attributes: Optional[JsonDict] = None,
): ):
@ -1609,6 +1659,6 @@ class PasswordProvider:
# This might return an awaitable; if it does, block the log out # This might return an awaitable; if it does, block the log out
# until it completes. # until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,) await maybe_awaitable(
if inspect.isawaitable(result): g(user_id=user_id, device_id=device_id, access_token=access_token,)
await result )
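The `inspect.isawaitable` dance is replaced by `maybe_awaitable`, which centralises the same check. Conceptually it behaves like the following sketch (a simplified rendition of the idea, not the exact implementation in `synapse.util.async_helpers`):

```python
import inspect


async def maybe_awaitable(value):
    """Await `value` if it is awaitable, otherwise return it unchanged.

    This lets callers treat sync and async callbacks uniformly, e.g.
    password-provider hooks that may or may not be coroutines.
    """
    if inspect.isawaitable(value):
        return await value
    return value
```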

View file

@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
import attr
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError from synapse.api.errors import HttpResponseException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
@ -29,6 +32,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""
def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description
def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
class CasHandler: class CasHandler:
""" """
Utility class to handle the response from a CAS SSO service. Utility class to handle the response from a CAS SSO service.
@ -40,6 +63,7 @@ class CasHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._hostname = hs.hostname self._hostname = hs.hostname
self._store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
@ -50,6 +74,11 @@ class CasHandler:
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
# identifier for the external_ids table
self._auth_provider_id = "cas"
self._sso_handler = hs.get_sso_handler()
def _build_service_param(self, args: Dict[str, str]) -> str: def _build_service_param(self, args: Dict[str, str]) -> str:
""" """
Generates a value to use as the "service" parameter when redirecting or Generates a value to use as the "service" parameter when redirecting or
@ -69,14 +98,20 @@ class CasHandler:
async def _validate_ticket( async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str] self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]: ) -> CasResponse:
""" """
Validate a CAS ticket with the server, parse the response, and return the user and display name. Validate a CAS ticket with the server, and return the parsed response.
Args: Args:
ticket: The CAS ticket from the client. ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL. service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`. Should be the same as those passed to `get_redirect_url`.
Raises:
CasError: If there's an error parsing the CAS response.
Returns:
The parsed CAS response.
""" """
uri = self._cas_server_url + "/proxyValidate" uri = self._cas_server_url + "/proxyValidate"
args = { args = {
@ -89,66 +124,65 @@ class CasHandler:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data
body = pde.response body = pde.response
except HttpResponseException as e:
description = (
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code)
raise CasError("server_error", description) from e
user, attributes = self._parse_cas_response(body) return self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)
for required_attribute, required_value in self._cas_required_attributes.items(): def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return user, displayname
def _parse_cas_response(
self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
""" """
Retrieve the user and other parameters from the CAS response. Retrieve the user and other parameters from the CAS response.
Args: Args:
cas_response_body: The response from the CAS query. cas_response_body: The response from the CAS query.
Raises:
CasError: If there's an error parsing the CAS response.
Returns: Returns:
A tuple of the user and a mapping of other attributes. The parsed CAS response.
""" """
# Ensure the response is valid.
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise CasError(
"missing_service_response",
"root of CAS response is not serviceResponse",
)
success = root[0].tag.endswith("authenticationSuccess")
if not success:
raise CasError("unsucessful_response", "Unsuccessful CAS response")
# Iterate through the nodes and pull out the user and any extra attributes.
user = None user = None
attributes = {} attributes = {}
try: for child in root[0]:
root = ET.fromstring(cas_response_body) if child.tag.endswith("user"):
if not root.tag.endswith("serviceResponse"): user = child.text
raise Exception("root of CAS response is not serviceResponse") if child.tag.endswith("attributes"):
success = root[0].tag.endswith("authenticationSuccess") for attribute in child:
for child in root[0]: # ElementTree library expands the namespace in
if child.tag.endswith("user"): # attribute tags to the full URL of the namespace.
user = child.text # We don't care about namespace here and it will always
if child.tag.endswith("attributes"): # be encased in curly braces, so we remove them.
for attribute in child: tag = attribute.tag
# ElementTree library expands the namespace in if "}" in tag:
# attribute tags to the full URL of the namespace. tag = tag.split("}")[1]
# We don't care about namespace here and it will always attributes[tag] = attribute.text
# be encased in curly braces, so we remove them.
tag = attribute.tag # Ensure a user was found.
if "}" in tag: if user is None:
tag = tag.split("}")[1] raise CasError("no_user", "CAS response does not contain user")
attributes[tag] = attribute.text
if user is None: return CasResponse(user, attributes)
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
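To make the parsing concrete, here is a self-contained sketch of the same approach run against a made-up CAS 2.0 response (the XML body is an example, not output from any real server):

```python
from xml.etree import ElementTree as ET

body = b"""<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes><cas:displayName>Alice</cas:displayName></cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(body)
assert root.tag.endswith("serviceResponse")
assert root[0].tag.endswith("authenticationSuccess")

user = None
attributes = {}
for child in root[0]:
    if child.tag.endswith("user"):
        user = child.text
    if child.tag.endswith("attributes"):
        for attribute in child:
            # ElementTree expands the namespace into the tag, e.g.
            # "{http://www.yale.edu/tp/cas}displayName"; strip that prefix.
            tag = attribute.tag
            if "}" in tag:
                tag = tag.split("}")[1]
            attributes[tag] = attribute.text

print(user, attributes)  # alice {'displayName': 'Alice'}
```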
def get_redirect_url(self, service_args: Dict[str, str]) -> str: def get_redirect_url(self, service_args: Dict[str, str]) -> str:
""" """
@ -201,59 +235,150 @@ class CasHandler:
args["redirectUrl"] = client_redirect_url args["redirectUrl"] = client_redirect_url
if session: if session:
args["session"] = session args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
# Pull out the user-agent and IP from the request. try:
user_agent = request.get_user_agent("") cas_response = await self._validate_ticket(ticket, args)
ip_address = self.hs.get_ip_from_request(request) except CasError as e:
logger.exception("Could not validate ticket")
self._sso_handler.render_error(request, e.error, e.error_description, 401)
return
# Get the matrix ID from the CAS username. await self._handle_cas_response(
user_id = await self._map_cas_user_to_matrix_user( request, cas_response, client_redirect_url, session
username, user_display_name, user_agent, ip_address
) )
if session: async def _handle_cas_response(
await self._auth_handler.complete_sso_ui_auth(
user_id, session, request,
)
else:
# If this is not a UI auth request then there must be a redirect URL.
assert client_redirect_url
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
async def _map_cas_user_to_matrix_user(
self, self,
remote_user_id: str, request: SynapseRequest,
display_name: Optional[str], cas_response: CasResponse,
user_agent: str, client_redirect_url: Optional[str],
ip_address: str, session: Optional[str],
) -> str: ) -> None:
""" """Handle a CAS response to a ticket request.
Given a CAS username, retrieve the user ID for it and possibly register the user.
Assumes that the response has been validated. Maps the user onto an MXID,
registering them if necessary, and returns a response to the browser.
Args: Args:
remote_user_id: The username from the CAS response. request: the incoming request from the browser. We'll respond to it with an
display_name: The display name from the CAS response. HTML page or a redirect
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns: cas_response: The parsed CAS response.
The user ID associated with this response.
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
session: The session parameter from the `/cas/ticket` HTTP request, if given.
This should be the UI Auth session id.
""" """
localpart = map_username_to_mxid_localpart(remote_user_id) # first check if we're doing a UIA
user_id = UserID(localpart, self._hostname).to_string() if session:
registered_user_id = await self._auth_handler.check_user_exists(user_id) return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, cas_response.username, session, request,
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
user_agent_ips=[(user_agent, ip_address)],
) )
return registered_user_id # otherwise, we're handling a login request.
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Call the mapper to register/login the user
# If this not a UI auth request than there must be a redirect URL.
assert client_redirect_url is not None
try:
await self._complete_cas_login(cas_response, request, client_redirect_url)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
async def _complete_cas_login(
self,
cas_response: CasResponse,
request: SynapseRequest,
client_redirect_url: str,
) -> None:
"""
Given a CAS response, complete the login flow
Retrieves the remote user ID, registers the user if necessary, and serves
a redirect back to the client with a login-token.
Args:
cas_response: The parsed CAS response.
request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
# Note that CAS does not support a mapping provider, so the logic is hard-coded.
localpart = map_username_to_mxid_localpart(cas_response.username)
async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Map from CAS attributes to user attributes.
"""
# Due to the grandfathering logic matching any previously registered
# mxids, no failures are expected.
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
display_name = cas_response.attributes.get(
self._cas_displayname_attribute, None
)
return UserAttributes(localpart=localpart, display_name=display_name)
async def grandfather_existing_users() -> Optional[str]:
# Since CAS did not always use the user_external_ids table, always
# attempt to map to existing users.
user_id = UserID(localpart, self._hostname).to_string()
logger.debug(
"Looking for existing account based on mapped %s", user_id,
)
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
return registered_user_id
return None
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
cas_response.username,
request,
client_redirect_url,
cas_response_to_user_attributes,
grandfather_existing_users,
)

View file

@ -135,7 +135,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it" 403, "You must be in the room to create an alias for it"
) )
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): if not await self.spam_checker.user_may_create_room_alias(
user_id, room_alias
):
raise AuthError(403, "This user is not permitted to create this alias") raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed( if not self.config.is_alias_creation_allowed(
@ -411,7 +413,7 @@ class DirectoryHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id): if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError( raise AuthError(
403, "This user is not permitted to publish rooms to the room list" 403, "This user is not permitted to publish rooms to the room list"
) )
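These call sites now `await` the spam checker because module callbacks are allowed to be coroutines (the wrapper awaits sync results via `maybe_awaitable`). A minimal sketch of a module-side callback under that assumption (the class name and policy are invented):

```python
class ExampleSpamChecker:
    """Illustrative third-party spam-checker module."""

    def __init__(self, config, api=None):
        # `api` is the module API handed to newer spam checkers;
        # older modules are constructed with config alone.
        self._api = api

    async def user_may_create_room_alias(self, user_id, room_alias) -> bool:
        # Async checks (e.g. a database or HTTP lookup) are now possible
        # because the caller awaits this result.
        return not user_id.startswith("@spammer")

    async def user_may_publish_room(self, user_id: str, room_id: str) -> bool:
        return True
```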

View file

@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler() self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config self.config = hs.config
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler() self._replication = hs.get_replication_data_handler()
@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites: if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites") raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite( if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id event.sender, event.state_key, event.room_id
): ):
raise SynapseError( raise SynapseError(

View file

@ -29,7 +29,7 @@ def _create_rerouter(func_name):
async def f(self, group_id, *args, **kwargs): async def f(self, group_id, *args, **kwargs):
if not GroupID.is_valid(group_id): if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s was not legal group ID" % (group_id,)) raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
return await getattr(self.groups_server_handler, func_name)( return await getattr(self.groups_server_handler, func_name)(

View file

@ -46,15 +46,17 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super().__init__(hs) super().__init__(hs)
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs) self.http_client = SimpleHttpClient(hs)
# We create a blacklisting instance of SimpleHttpClient for contacting identity # An HTTP client for contacting identity servers specified by clients.
# servers specified by clients
self.blacklisting_http_client = SimpleHttpClient( self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist hs, ip_blacklist=hs.config.federation_ip_range_blacklist
) )
self.federation_http_client = hs.get_http_client() self.federation_http_client = hs.get_federation_http_client()
self.hs = hs self.hs = hs
self._web_client_location = hs.config.invite_client_location
async def threepid_from_creds( async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str] self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]: ) -> Optional[JsonDict]:
@ -803,6 +805,9 @@ class IdentityHandler(BaseHandler):
"sender_display_name": inviter_display_name, "sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url, "sender_avatar_url": inviter_avatar_url,
} }
# If a custom web client location is available, include it in the request.
if self._web_client_location:
invite_config["org.matrix.web_client_location"] = self._web_client_location
# Add the identity service access token to the JSON body and use the v2 # Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present # Identity Service endpoints if id_access_token is present

View file

@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str, member_event_id: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = await self.state_store.get_state_for_event(member_event_id)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:

View file

@ -746,7 +746,7 @@ class EventCreationHandler:
event.sender, event.sender,
) )
spam_error = self.spam_checker.check_event_for_spam(event) spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error: if spam_error:
if not isinstance(spam_error, str): if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"
@ -1264,7 +1264,7 @@ class EventCreationHandler:
event, context = await self.create_event( event, context = await self.create_event(
requester, requester,
{ {
"type": "org.matrix.dummy_event", "type": EventTypes.Dummy,
"content": {}, "content": {},
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,

View file

@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
@ -674,38 +672,29 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e)) self._sso_handler.render_error(request, "invalid_token", str(e))
return return
# Pull out the user-agent and IP from the request. # first check if we're doing a UIA
user_agent = request.get_user_agent("") if ui_auth_session_id:
ip_address = self.hs.get_ip_from_request(request) try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
logger.exception("Could not extract remote user id")
self._sso_handler.render_error(request, "mapping_error", str(e))
return
return await self._sso_handler.complete_sso_ui_auth_request(
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
)
# otherwise, it's a login
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
user_id = await self._map_userinfo_to_user( await self._complete_oidc_login(
userinfo, token, user_agent, ip_address userinfo, token, request, client_redirect_url
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token( def _generate_oidc_session_token(
self, self,
@ -828,10 +817,14 @@ class OidcHandler(BaseHandler):
now = self.clock.time_msec() now = self.clock.time_msec()
return now < expiry return now < expiry
async def _map_userinfo_to_user( async def _complete_oidc_login(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str self,
) -> str: userinfo: UserInfo,
"""Maps a UserInfo object to a mxid. token: Token,
request: SynapseRequest,
client_redirect_url: str,
) -> None:
"""Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`. is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@ -843,27 +836,23 @@ class OidcHandler(BaseHandler):
If a user already exists with the mxid we've mapped and allow_existing_users If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception. is disabled, raise an exception.
Otherwise, render a redirect back to the client_redirect_url with a loginToken.
Args: Args:
userinfo: an object representing the user userinfo: an object representing the user
token: a dict with the tokens obtained from the provider token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request. request: The request to respond to
ip_address: The IP address of the client making the request. client_redirect_url: The redirect URL passed in by the client.
Raises: Raises:
MappingException: if there was an error while mapping some properties MappingException: if there was an error while mapping some properties
Returns:
The mxid of the user
""" """
try: try:
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e: except Exception as e:
raise MappingException( raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,) "Failed to extract subject from OIDC response: %s" % (e,)
) )
# Some OIDC providers use integer IDs, but Synapse expects external IDs
# to be strings.
remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we # Older mapping providers don't accept the `failures` argument, so we
# try and detect support. # try and detect support.
@ -924,18 +913,41 @@ class OidcHandler(BaseHandler):
return None return None
return await self._sso_handler.get_mxid_from_sso( # Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self._auth_provider_id,
remote_user_id, remote_user_id,
user_agent, request,
ip_address, client_redirect_url,
oidc_response_to_user_attributes, oidc_response_to_user_attributes,
grandfather_existing_users, grandfather_existing_users,
extra_attributes,
) )
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
"""Extract the unique remote id from an OIDC UserInfo block
Args:
userinfo: An object representing the user given by the OIDC provider
Returns:
remote user id
"""
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
# Some OIDC providers use integer IDs, but Synapse expects external IDs
# to be strings.
return str(remote_user_id)
UserAttributeDict = TypedDict( UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]} "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
) )
C = TypeVar("C") C = TypeVar("C")
@ -1016,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
@attr.s @attr.s
class JinjaOidcMappingConfig: class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str subject_claim = attr.ib(type=str)
localpart_template = attr.ib() # type: Template localpart_template = attr.ib(type=Optional[Template])
display_name_template = attr.ib() # type: Optional[Template] display_name_template = attr.ib(type=Optional[Template])
extra_attributes = attr.ib() # type: Dict[str, Template] extra_attributes = attr.ib(type=Dict[str, Template])
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@ -1035,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig: def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub") subject_claim = config.get("subject_claim", "sub")
if "localpart_template" not in config: localpart_template = None # type: Optional[Template]
raise ConfigError( if "localpart_template" in config:
"missing key: oidc_config.user_mapping_provider.config.localpart_template" try:
) localpart_template = env.from_string(config["localpart_template"])
except Exception as e:
try: raise ConfigError(
localpart_template = env.from_string(config["localpart_template"]) "invalid jinja template", path=["localpart_template"]
except Exception as e: ) from e
raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
% (e,)
)
display_name_template = None # type: Optional[Template] display_name_template = None # type: Optional[Template]
if "display_name_template" in config: if "display_name_template" in config:
@ -1054,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name_template = env.from_string(config["display_name_template"]) display_name_template = env.from_string(config["display_name_template"])
except Exception as e: except Exception as e:
raise ConfigError( raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" "invalid jinja template", path=["display_name_template"]
% (e,) ) from e
)
extra_attributes = {} # type: Dict[str, Template] extra_attributes = {} # type: Dict[str, Template]
if "extra_attributes" in config: if "extra_attributes" in config:
extra_attributes_config = config.get("extra_attributes") or {} extra_attributes_config = config.get("extra_attributes") or {}
if not isinstance(extra_attributes_config, dict): if not isinstance(extra_attributes_config, dict):
raise ConfigError( raise ConfigError("must be a dict", path=["extra_attributes"])
"oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
)
for key, value in extra_attributes_config.items(): for key, value in extra_attributes_config.items():
try: try:
extra_attributes[key] = env.from_string(value) extra_attributes[key] = env.from_string(value)
except Exception as e: except Exception as e:
raise ConfigError( raise ConfigError(
"invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" "invalid jinja template", path=["extra_attributes", key]
% (key, e) ) from e
)
return JinjaOidcMappingConfig( return JinjaOidcMappingConfig(
subject_claim=subject_claim, subject_claim=subject_claim,
@ -1088,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
async def map_user_attributes( async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict: ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip() localpart = None
# Ensure only valid characters are included in the MXID. if self._config.localpart_template:
localpart = map_username_to_mxid_localpart(localpart) localpart = self._config.localpart_template.render(user=userinfo).strip()
# Append suffix integer if last call to this function failed to produce # Ensure only valid characters are included in the MXID.
# a usable mxid. localpart = map_username_to_mxid_localpart(localpart)
localpart += str(failures) if failures else ""
# Append suffix integer if last call to this function failed to produce
# a usable mxid.
localpart += str(failures) if failures else ""
display_name = None # type: Optional[str] display_name = None # type: Optional[str]
if self._config.display_name_template is not None: if self._config.display_name_template is not None:

View file

@ -13,18 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler): class ReceiptsHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
@ -37,7 +39,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
async def _received_remote_receipt(self, origin, content): async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS. """Called when we receive an EDU of type m.receipt from a remote HS.
""" """
receipts = [] receipts = []
@ -64,11 +66,11 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts) await self._handle_new_receipts(receipts)
async def _handle_new_receipts(self, receipts): async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier. """Takes a list of receipts, stores them and informs the notifier.
""" """
min_batch_id = None min_batch_id = None # type: Optional[int]
max_batch_id = None max_batch_id = None # type: Optional[int]
for receipt in receipts: for receipt in receipts:
res = await self.store.insert_receipt( res = await self.store.insert_receipt(
@ -90,7 +92,8 @@ class ReceiptsHandler(BaseHandler):
if max_batch_id is None or max_persisted_id > max_batch_id: if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id max_batch_id = max_persisted_id
if min_batch_id is None: # Either both of these should be None or neither.
if min_batch_id is None or max_batch_id is None:
# no new receipts # no new receipts
return False return False
@ -98,15 +101,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
await maybe_awaitable( await self.hs.get_pusherpool().on_new_receipts(
self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids
min_batch_id, max_batch_id, affected_room_ids
)
) )
return True return True
async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str
) -> None:
"""Called when a client tells us a local user has read up to the given """Called when a client tells us a local user has read up to the given
event_id in the room. event_id in the room.
""" """
@ -126,10 +129,12 @@ class ReceiptsHandler(BaseHandler):
class ReceiptEventSource: class ReceiptEventSource:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def get_new_events(self, from_key, room_ids, **kwargs): async def get_new_events(
self, from_key: int, room_ids: List[str], **kwargs
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key) from_key = int(from_key)
to_key = self.get_current_key() to_key = self.get_current_key()
@ -174,5 +179,5 @@ class ReceiptEventSource:
return (events, to_key) return (events, to_key)
def get_current_key(self, direction="f"): def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_receipt_stream_id() return self.store.get_max_receipt_stream_id()

View file

@ -192,7 +192,7 @@ class RegistrationHandler(BaseHandler):
""" """
self.check_registration_ratelimit(address) self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam( result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [], threepid, localpart, user_agent_ips or [],
) )
@ -637,6 +637,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str], device_id: Optional[str],
initial_display_name: Optional[str], initial_display_name: Optional[str],
is_guest: bool = False, is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token. """Register a device for a user and generate an access token.
@ -658,6 +659,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id, device_id=device_id,
initial_display_name=initial_display_name, initial_display_name=initial_display_name,
is_guest=is_guest, is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
) )
return r["device_id"], r["access_token"] return r["device_id"], r["access_token"]
@ -679,7 +681,10 @@ class RegistrationHandler(BaseHandler):
) )
else: else:
access_token = await self._auth_handler.get_access_token_for_user_id( access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
) )
return (registered_device_id, access_token) return (registered_device_id, access_token)

View file

@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
HistoryVisibility,
JoinRules, JoinRules,
Membership, Membership,
     RoomCreationPreset,
@@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler):
         self._presets_dict = {
             RoomCreationPreset.PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": False,
                 "guest_can_join": True,
                 "power_level_content_override": {"invite": 0},
             },
             RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": True,
                 "guest_can_join": True,
                 "power_level_content_override": {"invite": 0},
             },
             RoomCreationPreset.PUBLIC_CHAT: {
                 "join_rules": JoinRules.PUBLIC,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": False,
                 "guest_can_join": False,
                 "power_level_content_override": {},
@@ -358,7 +359,7 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
-        if not self.spam_checker.user_may_create_room(user_id):
+        if not await self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")
         creation_content = {
@@ -440,6 +441,7 @@ class RoomCreationHandler(BaseHandler):
             invite_list=[],
             initial_state=initial_state,
             creation_content=creation_content,
+            ratelimit=False,
         )
         # Transfer membership events
@@ -608,7 +610,7 @@ class RoomCreationHandler(BaseHandler):
                 403, "You are not permitted to create rooms", Codes.FORBIDDEN
             )
-        if not is_requester_admin and not self.spam_checker.user_may_create_room(
+        if not is_requester_admin and not await self.spam_checker.user_may_create_room(
             user_id
         ):
             raise SynapseError(403, "You are not permitted to create rooms")
@@ -747,6 +749,7 @@ class RoomCreationHandler(BaseHandler):
             room_alias=room_alias,
             power_level_content_override=power_level_content_override,
             creator_join_profile=creator_join_profile,
+            ratelimit=ratelimit,
         )
         if "name" in config:
@@ -850,6 +853,7 @@ class RoomCreationHandler(BaseHandler):
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
+        ratelimit: bool = True,
     ) -> int:
         """Sends the initial events into a new room.
@@ -896,7 +900,7 @@ class RoomCreationHandler(BaseHandler):
             creator.user,
             room_id,
             "join",
-            ratelimit=False,
+            ratelimit=ratelimit,
             content=creator_join_profile,
         )
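Two independent changes run through this file: spam-checker callbacks are now awaited (they may be coroutines), and a `ratelimit` flag is threaded down so room upgrades can skip the creator-join rate limiter. A minimal sketch of the shim such an async migration typically relies on; `maybe_await` is illustrative, not Synapse's actual API:

    import inspect
    from typing import Any

    async def maybe_await(value: Any) -> Any:
        # If the callback returned a coroutine (a new-style async checker),
        # await it; otherwise pass the plain boolean through unchanged.
        if inspect.isawaitable(value):
            return await value
        return value

With a helper like this, `allowed = await maybe_await(checker.user_may_create_room(user_id))` accepts both sync and async checker implementations during the transition.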

View file

@@ -15,19 +15,22 @@
 import logging
 from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
 from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.response_cache import ResponseCache
 from ._base import BaseHandler
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 logger = logging.getLogger(__name__)
 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 class RoomListHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
-        self.response_cache = ResponseCache(hs, "room_list")
+        self.response_cache = ResponseCache(
+            hs, "room_list"
+        )  # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
         self.remote_response_cache = ResponseCache(
             hs, "remote_room_list", timeout_ms=30 * 1000
-        )
+        )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
     async def get_local_public_room_list(
         self,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        network_tuple=EMPTY_THIRD_PARTY_ID,
-        from_federation=False,
-    ):
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+        from_federation: bool = False,
+    ) -> JsonDict:
         """Generate a local public room list.
         There are multiple different lists: the main one plus one per third
         party network. A client can ask for a specific list or to return all.
         Args:
-            limit (int|None)
-            since_token (str|None)
-            search_filter (dict|None)
-            network_tuple (ThirdPartyInstanceID): Which public list to use.
+            limit
+            since_token
+            search_filter
+            network_tuple: Which public list to use.
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation (bool): true iff the request comes from the federation
-                API
+            from_federation: true iff the request comes from the federation API
         """
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
         self,
         limit: Optional[int] = None,
         since_token: Optional[str] = None,
-        search_filter: Optional[Dict] = None,
+        search_filter: Optional[dict] = None,
         network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
         from_federation: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Generate a public room list.
         Args:
             limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
-            bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+            bounds = (
+                batch_token.last_joined_members,
+                batch_token.last_room_id,
+            )  # type: Optional[Tuple[int, str]]
             forwards = batch_token.direction_is_forward
+            has_batch_token = True
         else:
-            batch_token = None
             bounds = None
             forwards = True
+            has_batch_token = False
         # we request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
@@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler):
                 "canonical_alias": room["canonical_alias"],
                 "num_joined_members": room["joined_members"],
                 "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"] == "world_readable",
+                "world_readable": room["history_visibility"]
+                == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
             }
@@ -168,7 +177,7 @@ class RoomListHandler(BaseHandler):
         results = [build_room_entry(r) for r in results]
-        response = {}
+        response = {}  # type: JsonDict
         num_results = len(results)
         if limit is not None:
             more_to_come = num_results == probing_limit
@@ -186,7 +195,7 @@ class RoomListHandler(BaseHandler):
                 initial_entry = results[0]
                 if forwards:
-                    if batch_token:
+                    if has_batch_token:
                         # If there was a token given then we assume that there
                         # must be previous results.
                         response["prev_batch"] = RoomListNextBatch(
@@ -202,7 +211,7 @@ class RoomListHandler(BaseHandler):
                             direction_is_forward=True,
                         ).to_token()
                 else:
-                    if batch_token:
+                    if has_batch_token:
                         response["next_batch"] = RoomListNextBatch(
                             last_joined_members=final_entry["num_joined_members"],
                             last_room_id=final_entry["room_id"],
@@ -292,7 +301,7 @@ class RoomListHandler(BaseHandler):
             return None
         # Return whether this room is open to federation users or not
-        create_event = current_state.get((EventTypes.Create, ""))
+        create_event = current_state[EventTypes.Create, ""]
         result["m.federate"] = create_event.content.get("m.federate", True)
         name_event = current_state.get((EventTypes.Name, ""))
@@ -317,7 +326,7 @@ class RoomListHandler(BaseHandler):
         visibility = None
         if visibility_event:
             visibility = visibility_event.content.get("history_visibility", None)
-        result["world_readable"] = visibility == "world_readable"
+        result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
         guest_event = current_state.get((EventTypes.GuestAccess, ""))
         guest = None
@@ -335,13 +344,13 @@ class RoomListHandler(BaseHandler):
     async def get_remote_public_room_list(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -398,13 +407,13 @@ class RoomListHandler(BaseHandler):
     async def _get_remote_list_cached(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
@@ -455,24 +464,24 @@ class RoomListNextBatch(
     REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
     @classmethod
-    def from_token(cls, token):
+    def from_token(cls, token: str) -> "RoomListNextBatch":
         decoded = msgpack.loads(decode_base64(token), raw=False)
         return RoomListNextBatch(
             **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
         )
-    def to_token(self):
+    def to_token(self) -> str:
         return encode_base64(
             msgpack.dumps(
                 {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
             )
         )
-    def copy_and_replace(self, **kwds):
+    def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
         return self._replace(**kwds)
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
     if search_filter and search_filter.get("generic_search_term", None):
         generic_search_term = search_filter["generic_search_term"].upper()
         if generic_search_term in room_entry.get("name", "").upper():
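For context on the `from_token`/`to_token` pair above: the pagination token is just a msgpack-packed dict encoded as unpadded base64, so it round-trips losslessly while remaining opaque to clients. A self-contained sketch, assuming only `msgpack` and `unpaddedbase64`; the short keys stand in for `KEY_DICT`'s abbreviations, which are not shown in this hunk:

    import msgpack
    from unpaddedbase64 import decode_base64, encode_base64

    def to_token(batch: dict) -> str:
        # Pack the batch dict with msgpack, then base64-encode it (unpadded)
        # so it is safe to hand to clients as an opaque pagination token.
        return encode_base64(msgpack.dumps(batch))

    def from_token(token: str) -> dict:
        return msgpack.loads(decode_base64(token), raw=False)

    token = to_token({"m": 42, "r": "!room:example.org", "d": True})
    assert from_token(token) == {"m": 42, "r": "!room:example.org", "d": True}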

View file

@@ -203,7 +203,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Only rate-limit if the user actually joined the room, otherwise we'll end
         # up blocking profile updates.
-        if newly_joined:
+        if newly_joined and ratelimit:
             time_now_s = self.clock.time()
             (
                 allowed,
@@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             )
             block_invite = True
-        if not self.spam_checker.user_may_invite(
+        if not await self.spam_checker.user_may_invite(
             requester.user.to_string(), target.to_string(), room_id
         ):
             logger.info("Blocking invite due to spam checker")
@@ -488,17 +488,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise AuthError(403, "Guest access not allowed")
             if not is_host_in_room:
-                time_now_s = self.clock.time()
-                (
-                    allowed,
-                    time_allowed,
-                ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
-                if not allowed:
-                    raise LimitExceededError(
-                        retry_after_ms=int(1000 * (time_allowed - time_now_s))
-                    )
+                if ratelimit:
+                    time_now_s = self.clock.time()
+                    (
+                        allowed,
+                        time_allowed,
+                    ) = self._join_rate_limiter_remote.can_requester_do_action(
+                        requester,
+                    )
+
+                    if not allowed:
+                        raise LimitExceededError(
+                            retry_after_ms=int(1000 * (time_allowed - time_now_s))
+                        )
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
                     remote_room_hosts.append(inviter.domain)
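The rate-limit checks above follow a `(allowed, time_allowed)` convention and convert a refusal into an error carrying a retry hint for the client. A compact sketch of that gate; `LimitExceededError` here is an illustrative stand-in for Synapse's class, and `can_do_action` for the limiter call:

    from typing import Callable, Tuple

    class LimitExceededError(Exception):
        """Illustrative stand-in for Synapse's LimitExceededError."""
        def __init__(self, retry_after_ms: int):
            super().__init__("rate limited, retry in %d ms" % retry_after_ms)
            self.retry_after_ms = retry_after_ms

    def rate_limit_gate(
        time_now_s: float, can_do_action: Callable[[], Tuple[bool, float]]
    ) -> None:
        # Ask the limiter; on refusal, surface when the action next becomes legal.
        allowed, time_allowed = can_do_action()
        if not allowed:
            raise LimitExceededError(
                retry_after_ms=int(1000 * (time_allowed - time_now_s))
            )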

View file

@@ -34,7 +34,6 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
-from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
         super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
@@ -81,9 +78,6 @@ class SamlHandler(BaseHandler):
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
-        # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
         self._sso_handler = hs.get_sso_handler()
     def handle_redirect_request(
@@ -167,6 +161,29 @@ class SamlHandler(BaseHandler):
             return
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to redirect
+                back to
+        """
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
@@ -183,6 +200,24 @@ class SamlHandler(BaseHandler):
             saml2_auth.in_response_to, None
         )
+        # first check if we're doing a UIA
+        if current_session and current_session.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+            except MappingException as e:
+                logger.exception("Failed to extract remote user id from SAML response")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id,
+                remote_user_id,
+                current_session.ui_auth_session_id,
+                request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for requirement in self._saml2_attribute_requirements:
@@ -192,63 +227,39 @@ class SamlHandler(BaseHandler):
             )
             return
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_saml_response_to_user(
-                saml2_auth, relay_state, user_agent, ip_address
-            )
+            await self._complete_saml_login(saml2_auth, request, relay_state)
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
+    async def _complete_saml_login(
         self,
         saml2_auth: saml2.response.AuthnResponse,
+        request: SynapseRequest,
         client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+    ) -> None:
         """
-        Given a SAML response, retrieve the user ID for it and possibly register the user.
+        Given a SAML response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
+
         Args:
             saml2_auth: The parsed SAML2 response.
+            request: The request to respond to
             client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-            The user ID associated with this response.
+
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+        remote_user_id = self._remote_id_from_saml_response(
             saml2_auth, client_redirect_url
         )
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
         async def saml_response_to_remapped_user_attributes(
             failures: int,
         ) -> UserAttributes:
@@ -294,16 +305,44 @@ class SamlHandler(BaseHandler):
             return None
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            return await self._sso_handler.get_mxid_from_sso(
-                self._auth_provider_id,
-                remote_user_id,
-                user_agent,
-                ip_address,
-                saml_response_to_remapped_user_attributes,
-                grandfather_existing_users,
-            )
+        await self._sso_handler.complete_sso_login_request(
+            self._auth_provider_id,
+            remote_user_id,
+            request,
+            client_redirect_url,
+            saml_response_to_remapped_user_attributes,
+            grandfather_existing_users,
+        )
+
+    def _remote_id_from_saml_response(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extract the unique remote id from a SAML2 AuthnResponse
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException if there was an error extracting the user id
+        """
+        # It's not obvious why we need to pass in the redirect URI to the mapping
+        # provider, but we do :/
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+        if not remote_user_id:
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
+        return remote_user_id
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
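A hedged sketch of the hook `_remote_id_from_saml_response` calls into: a user-mapping provider's `get_remote_user_id()` typically pulls a stable attribute out of the pysaml2 `AuthnResponse`'s attribute-value assertions (its `ava` dict). The attribute name "uid" is an assumption for illustration, not part of this diff:

    # Runs against a parsed pysaml2 AuthnResponse; "uid" is illustrative.
    def get_remote_user_id(saml_response, client_redirect_url):
        values = saml_response.ava.get("uid", [])
        # Returning None makes the caller raise a MappingException, as shown above.
        return values[0] if values else None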

View file

@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 import attr
+from typing_extensions import NoReturn
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
+from twisted.web.http import Request
+
+from synapse.api.errors import RedirectException, SynapseError
 from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -37,22 +42,70 @@ class MappingException(Exception):
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
     emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
+
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=List[str])
+
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._clock = hs.get_clock()
+        self._store = hs.get_datastore()
+        self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._error_template = hs.config.sso_error_template
+        self._auth_handler = hs.get_auth_handler()
+
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
     def render_error(
-        self, request, error: str, error_description: Optional[str] = None
+        self,
+        request: Request,
+        error: str,
+        error_description: Optional[str] = None,
+        code: int = 400,
     ) -> None:
         """Renders the error template and responds with it.
@@ -64,11 +117,12 @@ class SsoHandler(BaseHandler):
             We'll respond with an HTML page describing the error.
             error: A technical identifier for this error.
             error_description: A human-readable description of the error.
+            code: The integer error code (an HTTP response code)
         """
         html = self._error_template.render(
             error=error, error_description=error_description
         )
-        respond_with_html(request, 400, html)
+        respond_with_html(request, code, html)
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
@@ -95,7 +149,7 @@ class SsoHandler(BaseHandler):
         )
         # Check if we already have a mapping for this user.
-        previously_registered_user_id = await self.store.get_user_by_external_id(
+        previously_registered_user_id = await self._store.get_user_by_external_id(
             auth_provider_id, remote_user_id,
         )
@@ -112,15 +166,16 @@ class SsoHandler(BaseHandler):
         # No match.
         return None
-    async def get_mxid_from_sso(
+    async def complete_sso_login_request(
         self,
         auth_provider_id: str,
         remote_user_id: str,
-        user_agent: str,
-        ip_address: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
-    ) -> str:
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+        extra_login_attributes: Optional[JsonDict] = None,
+    ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -139,12 +194,18 @@ class SsoHandler(BaseHandler):
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
+        Finally, we generate a redirect to the supplied redirect uri, with a login token
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
             remote_user_id: The unique identifier from the SSO provider.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
@@ -156,12 +217,13 @@ class SsoHandler(BaseHandler):
                     to the user.
                 RedirectException to redirect to an additional page (e.g.
                     to prompt the user for more information).
             grandfather_existing_users: A callable which can return a previously
                 existing matrix ID. The SSO ID is then linked to the returned
                 matrix ID.
-        Returns:
-             The user ID associated with the SSO response.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -169,24 +231,55 @@ class SsoHandler(BaseHandler):
             to an additional page. (e.g. to prompt for more information)
         """
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
-        )
-        if previously_registered_user_id:
-            return previously_registered_user_id
-
-        # Check for grandfathering of users.
-        if grandfather_existing_users:
-            previously_registered_user_id = await grandfather_existing_users()
-            if previously_registered_user_id:
-                # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
-                )
-                return previously_registered_user_id
-
-        # Otherwise, generate a new user.
+        # grab a lock while we try to find a mapping for this user. This seems...
+        # optimistic, especially for implementations that end up redirecting to
+        # interstitial pages.
+        with await self._mapping_lock.queue(auth_provider_id):
+            # first of all, check if we already have a mapping for this user
+            user_id = await self.get_sso_user_by_remote_user_id(
+                auth_provider_id, remote_user_id,
+            )
+
+            # Check for grandfathering of users.
+            if not user_id:
+                user_id = await grandfather_existing_users()
+                if user_id:
+                    # Future logins should also match this user ID.
+                    await self._store.record_user_external_id(
+                        auth_provider_id, remote_user_id, user_id
+                    )
+
+            # Otherwise, generate a new user.
+            if not user_id:
+                attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
+                user_id = await self._register_mapped_user(
+                    attributes,
+                    auth_provider_id,
+                    remote_user_id,
+                    request.get_user_agent(""),
+                    request.getClientIP(),
+                )
+
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_login_attributes
+        )
+
+    async def _call_attribute_mapper(
+        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+    ) -> UserAttributes:
+        """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
                 attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +301,12 @@ class SsoHandler(BaseHandler):
             )
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
             # Check if this mxid already exists
-            user_id = UserID(attributes.localpart, self.server_name).to_string()
-            if not await self.store.get_users_by_id_case_insensitive(user_id):
+            user_id = UserID(attributes.localpart, self._server_name).to_string()
+            if not await self._store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
@@ -224,10 +315,101 @@ class SsoHandler(BaseHandler):
             raise MappingException(
                 "Unable to generate a Matrix ID from the SSO response"
             )
+        return attributes
+
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            attributes: the user attributes returned by the user mapping provider.
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
+    async def _register_mapped_user(
+        self,
+        attributes: UserAttributes,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+            attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -238,7 +420,152 @@ class SsoHandler(BaseHandler):
             user_agent_ips=[(user_agent, ip_address)],
         )
-        await self.store.record_user_external_id(
+        await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
         return registered_user_id
+
+    async def complete_sso_ui_auth_request(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        ui_auth_session_id: str,
+        request: Request,
+    ) -> None:
+        """
+        Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+        Note that this requires that the user is mapped in the "user_external_ids"
+        table. This will be the case if they have ever logged in via SAML or OIDC in
+        recentish synapse versions, but may not be for older users.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            ui_auth_session_id: The ID of the user-interactive auth session.
+            request: The request to complete.
+        """
+        user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        if not user_id:
+            logger.warning(
+                "Remote user %s/%s has not previously logged in here: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+            )
+            # Let the UIA flow handle this the same as if they presented creds for a
+            # different user.
+            user_id = ""
+
+        await self._auth_handler.complete_sso_ui_auth(
+            user_id, ui_auth_session_id, request
+        )
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+
+        Returns:
+            True if the username is available
+
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self, request: SynapseRequest, localpart: str, session_id: str
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+        attributes = UserAttributes(
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
+        )
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            request.get_user_agent(""),
+            request.getClientIP(),
+        )
+
+        logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
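The session store added above is a plain worker-local dict, swept lazily: there is no background timer, and expired entries are purged whenever any session-touching request arrives. A minimal model of that design, under the same assumptions:

    import attr
    from typing import Dict

    @attr.s(slots=True)
    class Session:
        expiry_time_ms = attr.ib(type=int)

    def expire_old_sessions(sessions: Dict[str, Session], now_ms: int) -> None:
        # Collect first, then delete, so the dict is never mutated mid-iteration.
        expired = [sid for sid, s in sessions.items() if s.expiry_time_ms <= now_ms]
        for sid in expired:
            del sessions[sid]

The lazy sweep keeps the handler's state simple at the cost of expired sessions lingering in memory until the next request touches the map.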

View file

@@ -554,7 +554,7 @@ class SyncHandler:
                 event.event_id, state_filter=state_filter
             )
             if event.is_state():
-                state_ids = state_ids.copy()
+                state_ids = dict(state_ids)
                 state_ids[(event.type, event.state_key)] = event.event_id
             return state_ids
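The `.copy()` to `dict(...)` change above matters when `state_ids` is a frozendict, as Synapse's state maps can be; presumably that is the motivation. A sketch of the difference:

    from frozendict import frozendict

    state_ids = frozendict({("m.room.name", ""): "$name_event"})

    # dict() always produces a plain, mutable copy...
    mutable = dict(state_ids)
    mutable[("m.room.topic", "")] = "$topic_event"

    # ...whereas frozendict.copy() returns another frozendict, which would
    # reject the item assignment performed in the hunk above.
    assert isinstance(state_ids.copy(), frozendict)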

View file

@@ -14,14 +14,19 @@
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
 from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 logger = logging.getLogger(__name__)
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     be in the directory or not when necessary.
     """
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.search_all_users = hs.config.user_directory_search_all_users
         self.spam_checker = hs.get_spam_checker()
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
         # we start populating the user directory
         self.clock.call_later(0, self.notify_new_event)
-    async def search_users(self, user_id, search_term, limit):
+    async def search_users(
+        self, user_id: str, search_term: str, limit: int
+    ) -> JsonDict:
         """Searches for users in directory
         Returns:
@@ -81,15 +88,15 @@ class UserDirectoryHandler(StateDeltasHandler):
         results = await self.store.search_user_dir(user_id, search_term, limit)
         # Remove any spammy users from the results.
-        results["results"] = [
-            user
-            for user in results["results"]
-            if not self.spam_checker.check_username_for_spam(user)
-        ]
+        non_spammy_users = []
+        for user in results["results"]:
+            if not await self.spam_checker.check_username_for_spam(user):
+                non_spammy_users.append(user)
+        results["results"] = non_spammy_users
         return results
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.update_user_directory:
@@ -107,27 +114,33 @@ class UserDirectoryHandler(StateDeltasHandler):
             self._is_processing = True
         run_as_background_process("user_directory.notify_new_event", process)
-    async def handle_local_profile_change(self, user_id, profile):
+    async def handle_local_profile_change(
+        self, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called to update index of our local user profiles when they change
         irrespective of any rooms the user may be in.
         """
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
-        is_support = await self.store.is_support_user(user_id)
         # Support users are for diagnostics and should not appear in the user directory.
-        if not is_support:
+        is_support = await self.store.is_support_user(user_id)
+        # Profile changes for a deactivated user should not appear in the user directory.
+        is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+        if not (is_support or is_deactivated):
             await self.store.update_profile_in_user_dir(
                 user_id, profile.display_name, profile.avatar_url
             )
-    async def handle_user_deactivated(self, user_id):
+    async def handle_user_deactivated(self, user_id: str) -> None:
         """Called when a user ID is deactivated
         """
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
         await self.store.remove_from_user_dir(user_id)
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
@@ -162,7 +175,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         await self.store.update_user_directory_stream_pos(max_pos)
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
         """Called with the state deltas to process
         """
         for delta in deltas:
@@ -232,16 +245,20 @@ class UserDirectoryHandler(StateDeltasHandler):
                 logger.debug("Ignoring irrelevant type: %r", typ)
     async def _handle_room_publicity_change(
-        self, room_id, prev_event_id, event_id, typ
-    ):
+        self,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        typ: str,
+    ) -> None:
         """Handle a room having potentially changed from/to world_readable/publicly
         joinable.
         Args:
-            room_id (str)
-            prev_event_id (str|None): The previous event before the state change
-            event_id (str|None): The new event after the state change
-            typ (str): Type of the event
+            room_id: The ID of the room which changed.
+            prev_event_id: The previous event before the state change
+            event_id: The new event after the state change
+            typ: Type of the event
         """
         logger.debug("Handling change for %s: %s", typ, room_id)
@@ -250,7 +267,7 @@ class UserDirectoryHandler(StateDeltasHandler):
                 prev_event_id,
                 event_id,
                 key_name="history_visibility",
-                public_value="world_readable",
+                public_value=HistoryVisibility.WORLD_READABLE,
             )
         elif typ == EventTypes.JoinRules:
             change = await self._get_key_change(
@@ -299,12 +316,14 @@ class UserDirectoryHandler(StateDeltasHandler):
             for user_id, profile in users_with_profile.items():
                 await self._handle_new_user(room_id, user_id, profile)
-    async def _handle_new_user(self, room_id, user_id, profile):
+    async def _handle_new_user(
+        self, room_id: str, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called when we might need to add user to directory
         Args:
-            room_id (str): room_id that user joined or started being public
-            user_id (str)
+            room_id: The room ID that the user joined, or that started being public
+            user_id
         """
         logger.debug("Adding new user to dir, %r", user_id)
@@ -352,12 +371,12 @@ class UserDirectoryHandler(StateDeltasHandler):
         if to_insert:
             await self.store.add_users_who_share_private_room(room_id, to_insert)
-    async def _handle_remove_user(self, room_id, user_id):
+    async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
         """Called when we might need to remove user from directory
         Args:
-            room_id (str): room_id that user left or stopped being public that
-            user_id (str)
+            room_id: The room ID that the user left, or that stopped being public
+            user_id
         """
         logger.debug("Removing user %r", user_id)
@@ -370,7 +389,13 @@ class UserDirectoryHandler(StateDeltasHandler):
         if len(rooms_user_is_in) == 0:
             await self.store.remove_from_user_dir(user_id)
-    async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+    async def _handle_profile_change(
+        self,
+        user_id: str,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+    ) -> None:
         """Check member event changes for any profile changes and update the
         database if there are.
         """

View file

@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
     return _scheduler
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
         return r
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self._nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        if attr == "nameResolver":
+            return self._nameResolver
+        else:
+            return getattr(self._reactor, attr)
+
+
 class BlacklistingAgentWrapper(Agent):
     """
     An Agent wrapper which will prevent access to IP addresses being accessed
@@ -293,22 +322,11 @@ class SimpleHttpClient:
             self.user_agent = self.user_agent.encode("ascii")
         if self._ip_blacklist:
-            real_reactor = hs.get_reactor()
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            nameResolver = IPBlacklistingResolver(
-                real_reactor, self._ip_whitelist, self._ip_blacklist
+            self.reactor = BlacklistingReactorWrapper(
+                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
             )
-
-            @implementer(IReactorPluggableNameResolver)
-            class Reactor:
-                def __getattr__(_self, attr):
-                    if attr == "nameResolver":
-                        return nameResolver
-                    else:
-                        return getattr(real_reactor, attr)
-
-            self.reactor = Reactor()
         else:
             self.reactor = hs.get_reactor()
@@ -703,11 +721,14 @@ class SimpleHttpClient:
         try:
             length = await make_deferred_yieldable(
-                readBodyToFile(response, output_stream, max_size)
+                read_body_with_max_size(response, output_stream, max_size)
             )
-        except SynapseError:
-            # This can happen e.g. because the body is too large.
-            raise
+        except BodyExceededMaxSize:
+            raise SynapseError(
+                502,
+                "Requested file is too large > %r bytes" % (max_size,),
+                Codes.TOO_LARGE,
+            )
         except Exception as e:
             raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
@@ -731,7 +752,11 @@ def _timeout_to_request_timed_out_error(f: Failure):
     return f
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+    """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -744,13 +769,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
         self.stream.write(data)
         self.length += len(data)
         if self.max_size is not None and self.length >= self.max_size:
-            self.deferred.errback(
-                SynapseError(
-                    502,
-                    "Requested file is too large > %r bytes" % (self.max_size,),
-                    Codes.TOO_LARGE,
-                )
-            )
+            self.deferred.errback(BodyExceededMaxSize())
             self.deferred = defer.Deferred()
             self.transport.loseConnection()
@@ -765,12 +784,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
             self.deferred.errback(reason)
-def readBodyToFile(
+def read_body_with_max_size(
     response: IResponse, stream: BinaryIO, max_size: Optional[int]
 ) -> defer.Deferred:
     """
     Read an HTTP response body to a file-object. Optionally enforcing a maximum file size.
+    If the maximum file size is reached, the returned Deferred will resolve to a
+    Failure with a BodyExceededMaxSize exception.
     Args:
         response: The HTTP response to read from.
         stream: The file-object to write to.
@@ -781,7 +803,7 @@ def readBodyToFile(
     """
     d = defer.Deferred()
-    response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+    response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
     return d
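A usage sketch of the new helper pair (`read_body_with_max_size` plus `BodyExceededMaxSize`), mirroring the `get_file` handling above. The `fetch_limited` wrapper and its error text are illustrative, not part of the commit:

    from io import BytesIO

    from synapse.api.errors import Codes, SynapseError
    from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
    from synapse.logging.context import make_deferred_yieldable

    async def fetch_limited(response, max_size: int) -> bytes:
        buf = BytesIO()
        try:
            # Resolves with the body length, or errbacks with BodyExceededMaxSize
            # once the protocol has seen max_size bytes and dropped the connection.
            await make_deferred_yieldable(
                read_body_with_max_size(response, buf, max_size)
            )
        except BodyExceededMaxSize:
            raise SynapseError(502, "Requested file is too large", Codes.TOO_LARGE)
        return buf.getvalue()

Raising a dedicated exception type from the protocol, rather than a fully-formed `SynapseError`, keeps the transport layer free of HTTP-status policy; each caller decides how to surface the failure.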

View file

@ -16,7 +16,7 @@ import logging
import urllib.parse import urllib.parse
from typing import List, Optional from typing import List, Optional
from netaddr import AddrFormatError, IPAddress from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -70,6 +71,7 @@ class MatrixFederationAgent:
reactor: IReactorCore, reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS], tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes, user_agent: bytes,
ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None, _srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None,
): ):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
self.user_agent = user_agent self.user_agent = user_agent
if _well_known_resolver is None: if _well_known_resolver is None:
# Note that the name resolver has already been wrapped in a
# IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver( _well_known_resolver = WellKnownResolver(
self._reactor, self._reactor,
agent=Agent( agent=BlacklistingAgentWrapper(
Agent(
self._reactor,
pool=self._pool,
contextFactory=tls_client_options_factory,
),
self._reactor, self._reactor,
pool=self._pool, ip_blacklist=ip_blacklist,
contextFactory=tls_client_options_factory,
), ),
user_agent=self.user_agent, user_agent=self.user_agent,
) )
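For illustration, constructing the agent with its new argument might look as follows; the blacklist ranges and user agent are placeholder values (Synapse itself passes hs.config.federation_ip_range_blacklist), and `reactor` and `tls_client_options_factory` are assumed to be in scope:

    from netaddr import IPSet

    from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent

    # Placeholder ranges; real deployments configure these.
    ip_blacklist = IPSet(["127.0.0.0/8", "10.0.0.0/8", "::1/128"])

    agent = MatrixFederationAgent(
        reactor,                     # assumed in scope
        tls_client_options_factory,  # assumed in scope
        b"Synapse/1.25",
        ip_blacklist,
    )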

View file

@@ -15,17 +15,19 @@
import logging import logging
import random import random
import time import time
from io import BytesIO
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, Optional, Tuple
import attr import attr
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody from twisted.web.client import RedirectAgent
from twisted.web.http import stringToDatetime from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IResponse from twisted.web.iweb import IAgent, IResponse
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache from synapse.util.caches.ttlcache import TTLCache
@@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period # lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
# The maximum size (in bytes) to allow a well-known file to be.
WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB
# Attempt to refetch a cached well-known N% of the TTL before it expires. # Attempt to refetch a cached well-known N% of the TTL before it expires.
# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then # e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
# we'll start trying to refetch 1 minute before it expires. # we'll start trying to refetch 1 minute before it expires.
@@ -229,6 +234,9 @@ class WellKnownResolver:
server_name: name of the server, from the requested url server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails. retry: Whether to retry the request if it fails.
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns: Returns:
Returns the response object and body. Response may be a non-200 response. Returns the response object and body. Response may be a non-200 response.
""" """
@@ -250,7 +258,11 @@ class WellKnownResolver:
b"GET", uri, headers=Headers(headers) b"GET", uri, headers=Headers(headers)
) )
) )
body = await make_deferred_yieldable(readBody(response)) body_stream = BytesIO()
await make_deferred_yieldable(
read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
)
body = body_stream.getvalue()
if 500 <= response.code < 600: if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,)) raise Exception("Non-200 response %s" % (response.code,))
@@ -259,6 +271,15 @@ class WellKnownResolver:
except defer.CancelledError: except defer.CancelledError:
# Bail if we've been cancelled # Bail if we've been cancelled
raise raise
except BodyExceededMaxSize:
# If the well-known file was too large, do not keep attempting
# to download it, but consider it a temporary error.
logger.warning(
"Requested .well-known file for %s is too large > %r bytes",
server_name.decode("ascii"),
WELL_KNOWN_MAX_SIZE,
)
raise _FetchWellKnownFailure(temporary=True)
except Exception as e: except Exception as e:
if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS: if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
logger.info("Error fetching %s: %s", uri_str, e) logger.info("Error fetching %s: %s", uri_str, e)

View file

@@ -26,11 +26,10 @@ import treq
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from prometheus_client import Counter from prometheus_client import Counter
from signedjson.sign import sign_json from signedjson.sign import sign_json
from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator from twisted.internet.task import _EPSILON, Cooperator
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse from twisted.web.iweb import IBodyProducer, IResponse
@@ -38,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics import synapse.metrics
import synapse.util.retryutils import synapse.util.retryutils
from synapse.api.errors import ( from synapse.api.errors import (
Codes,
FederationDeniedError, FederationDeniedError,
HttpResponseException, HttpResponseException,
RequestSendFailed, RequestSendFailed,
SynapseError,
) )
from synapse.http import QuieterFileBodyProducer from synapse.http import QuieterFileBodyProducer
from synapse.http.client import ( from synapse.http.client import (
BlacklistingAgentWrapper, BlacklistingAgentWrapper,
IPBlacklistingResolver, BlacklistingReactorWrapper,
BodyExceededMaxSize,
encode_query_args, encode_query_args,
readBodyToFile, read_body_with_max_size,
) )
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@@ -221,31 +223,22 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
self.server_name = hs.hostname self.server_name = hs.hostname
real_reactor = hs.get_reactor()
# We need to use a DNS resolver which filters out blacklisted IP # We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding. # addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver( self.reactor = BlacklistingReactorWrapper(
real_reactor, None, hs.config.federation_ip_range_blacklist hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
) )
@implementer(IReactorPluggableNameResolver)
class Reactor:
def __getattr__(_self, attr):
if attr == "nameResolver":
return nameResolver
else:
return getattr(real_reactor, attr)
self.reactor = Reactor()
user_agent = hs.version_string user_agent = hs.version_string
if hs.config.user_agent_suffix: if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii") user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent( self.agent = MatrixFederationAgent(
self.reactor, tls_client_options_factory, user_agent self.reactor,
tls_client_options_factory,
user_agent,
hs.config.federation_ip_range_blacklist,
) )
# Use a BlacklistingAgentWrapper to prevent circumventing the IP # Use a BlacklistingAgentWrapper to prevent circumventing the IP
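The deleted inline `Reactor` shim above is a good guide to what `BlacklistingReactorWrapper` presumably does; a hedged sketch of the pattern, not the class's actual body:

    from zope.interface import implementer

    from twisted.internet.interfaces import IReactorPluggableNameResolver

    from synapse.http.client import IPBlacklistingResolver

    @implementer(IReactorPluggableNameResolver)
    class BlacklistingReactorWrapperSketch:
        def __init__(self, reactor, ip_whitelist, ip_blacklist):
            self._reactor = reactor
            # Swap in a DNS resolver that refuses to return blacklisted
            # addresses, defeating DNS-rebinding tricks.
            self._name_resolver = IPBlacklistingResolver(
                reactor, ip_whitelist, ip_blacklist
            )

        def __getattr__(self, attr):
            if attr == "nameResolver":
                return self._name_resolver
            return getattr(self._reactor, attr)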
@@ -985,9 +978,15 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())
try: try:
d = readBodyToFile(response, output_stream, max_size) d = read_body_with_max_size(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor) d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d) length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (max_size,)
logger.warning(
"{%s} [%s] %s", request.txn_id, request.destination, msg,
)
raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"{%s} [%s] Error reading response: %s", "{%s} [%s] Error reading response: %s",

View file

@@ -275,6 +275,10 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON. formatting responses and errors as JSON.
""" """
def __init__(self, canonical_json=False, extract_context=False):
super().__init__(extract_context)
self.canonical_json = canonical_json
def _send_response( def _send_response(
self, request: Request, code: int, response_object: Any, self, request: Request, code: int, response_object: Any,
): ):
@@ -318,9 +322,7 @@ class JsonResource(DirectServeJsonResource):
) )
def __init__(self, hs, canonical_json=True, extract_context=False): def __init__(self, hs, canonical_json=True, extract_context=False):
super().__init__(extract_context) super().__init__(canonical_json, extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs = {}
self.hs = hs self.hs = hs
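Pulling `canonical_json` up into `DirectServeJsonResource` means other direct-serving resources can opt into canonical encoding without re-implementing it. A hypothetical subclass for illustration (`PingResource` is not part of this commit):

    from synapse.http.server import DirectServeJsonResource

    class PingResource(DirectServeJsonResource):
        def __init__(self):
            super().__init__(canonical_json=True)

        async def _async_render_GET(self, request):
            # Returning (code, json_object) lets the base class handle
            # serialisation, now honouring canonical_json.
            return 200, {"pong": True}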

View file

@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request # create a LogContext for this request
request_id = self.get_request_id() request_id = self.get_request_id()
logcontext = self.logcontext = LoggingContext(request_id) self.logcontext = LoggingContext(request_id, request=request_id)
logcontext.request = request_id
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.site.server_version_string)

View file

@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record): def copy_to(self, record):
pass pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self, rusage: "Optional[resource._RUsage]"): def start(self, rusage: "Optional[resource._RUsage]"):
pass pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope: # we also track the current scope:
record.scope = self.scope record.scope = self.scope
def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope
def start(self, rusage: "Optional[resource._RUsage]") -> None: def start(self, rusage: "Optional[resource._RUsage]") -> None:
""" """
Record that this logcontext is currently running. Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter): class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each """Logging filter that adds values from the current logging context to each
record. record.
Args:
**defaults: Default values to avoid formatters complaining about
missing fields
""" """
def __init__(self, **defaults) -> None: def __init__(self, request: str = ""):
self.defaults = defaults self._default_request = request
def filter(self, record) -> Literal[True]: def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record. """Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output. True to include the record in the log output.
""" """
context = current_context() context = current_context()
for key, value in self.defaults.items(): record.request = self._default_request
setattr(record, key, value)
# context should never be None, but if it somehow ends up being, then # context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for # we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake. # robustness' sake.
if context is not None: if context is not None:
context.copy_to(record) # Logging is interested in the request.
record.request = context.request
return True return True
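With the filter now taking a single default `request` value instead of arbitrary `**defaults`, wiring it into stdlib logging looks roughly like this; the handler and format string are illustrative:

    import logging

    from synapse.logging.context import LoggingContextFilter

    handler = logging.StreamHandler()
    # Every record gets a `request` attribute: the active context's
    # request id, or "" when no logging context is in effect.
    handler.addFilter(LoggingContextFilter(request=""))
    handler.setFormatter(
        logging.Formatter("%(asctime)s - %(request)s - %(levelname)s - %(message)s")
    )
    logging.getLogger("synapse").addHandler(handler)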

View file

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import threading import threading
from functools import wraps from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span from synapse.logging.opentracing import noop_context_manager, start_active_span
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
import resource import resource
@@ -199,19 +199,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc() _background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc()
with BackgroundProcessLoggingContext(desc) as context: with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
context.request = "%s-%i" % (desc, count)
try: try:
ctx = noop_context_manager() ctx = noop_context_manager()
if bg_start_span: if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request}) ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx: with ctx:
result = func(*args, **kwargs) return await maybe_awaitable(func(*args, **kwargs))
if inspect.isawaitable(result):
result = await result
return result
except Exception: except Exception:
logger.exception( logger.exception(
"Background process '%s' threw an exception", desc, "Background process '%s' threw an exception", desc,
@@ -249,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"] __slots__ = ["_proc"]
def __init__(self, name: str): def __init__(self, name: str, request: Optional[str] = None):
super().__init__(name) super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self) self._proc = _BackgroundProcess(name, self)
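Thanks to `maybe_awaitable`, `run_as_background_process` keeps accepting both plain and async callables; for example (the function bodies here are placeholders):

    from synapse.metrics.background_process_metrics import run_as_background_process

    def prune_cache():
        pass  # synchronous work

    async def flush_queue():
        pass  # asynchronous work

    # Both forms are fine: maybe_awaitable only awaits the result when
    # the wrapped function actually returned an awaitable.
    run_as_background_process("prune_cache", prune_cache)
    run_as_background_process("flush_queue", flush_queue)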

View file

@@ -34,7 +34,7 @@ from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
import synapse.server import synapse.server
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@@ -611,7 +611,9 @@ class Notifier:
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if state and "history_visibility" in state.content: if state and "history_visibility" in state.content:
return state.content["history_visibility"] == "world_readable" return (
state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE
)
else: else:
return False return False
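The bare string literal gives way to a named constant; judging by the spec's visibility values, `HistoryVisibility` in synapse.api.constants is presumably along these lines (a sketch, not the file's actual contents):

    class HistoryVisibility:
        INVITED = "invited"
        JOINED = "joined"
        SHARED = "shared"
        WORLD_READABLE = "world_readable"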

View file

@@ -13,7 +13,111 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional
import attr
from synapse.types import JsonDict, RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@attr.s(slots=True)
class PusherConfig:
"""Parameters necessary to configure a pusher."""
id = attr.ib(type=Optional[str])
user_name = attr.ib(type=str)
access_token = attr.ib(type=Optional[int])
profile_tag = attr.ib(type=str)
kind = attr.ib(type=str)
app_id = attr.ib(type=str)
app_display_name = attr.ib(type=str)
device_display_name = attr.ib(type=str)
pushkey = attr.ib(type=str)
ts = attr.ib(type=int)
lang = attr.ib(type=Optional[str])
data = attr.ib(type=Optional[JsonDict])
last_stream_ordering = attr.ib(type=int)
last_success = attr.ib(type=Optional[int])
failing_since = attr.ib(type=Optional[int])
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
return {
"app_display_name": self.app_display_name,
"app_id": self.app_id,
"data": self.data,
"device_display_name": self.device_display_name,
"kind": self.kind,
"lang": self.lang,
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
}
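A quick illustration of the new attrs-based config and the deliberately reduced view `as_dict` exposes; all field values below are invented:

    config = PusherConfig(
        id="1",
        user_name="@alice:example.org",
        access_token=1234,
        profile_tag="",
        kind="http",
        app_id="com.example.app",
        app_display_name="Example App",
        device_display_name="Alice's phone",
        pushkey="a_push_key",
        ts=1609945200000,
        lang="en",
        data={"url": "https://push.example.org/_matrix/push/v1/notify"},
        last_stream_ordering=0,
        last_success=None,
        failing_since=None,
    )
    # Sensitive fields such as access_token are not exposed.
    assert "access_token" not in config.as_dict()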
@attr.s(slots=True)
class ThrottleParams:
"""Parameters for controlling the rate of sending pushes via email."""
last_sent_ts = attr.ib(type=int)
throttle_ms = attr.ib(type=int)
class Pusher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusher_config.id
self.user_id = pusher_config.user_name
self.app_id = pusher_config.app_id
self.pushkey = pusher_config.pushkey
self.last_stream_ordering = pusher_config.last_stream_ordering
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation.
self.max_stream_ordering = self.store.get_room_max_stream_ordering()
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
self._start_processing()
@abc.abstractmethod
def _start_processing(self):
"""Start processing push notifications."""
raise NotImplementedError()
@abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def on_started(self, have_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
have_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
raise NotImplementedError()
@abc.abstractmethod
def on_stop(self) -> None:
raise NotImplementedError()
class PusherConfigException(Exception): class PusherConfigException(Exception):
def __init__(self, msg): """An error occurred when creating a pusher."""
super().__init__(msg)

View file

@@ -14,19 +14,22 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs) self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile # event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
async def handle_push_actions_for_event(self, event, context): async def handle_push_actions_for_event(
self, event: EventBase, context: EventContext
) -> None:
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context) await self.bulk_evaluator.action_for_event_by_user(event, context)

View file

@@ -15,16 +15,19 @@
# limitations under the License. # limitations under the License.
import copy import copy
from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
def list_with_base_rules(rawrules, use_new_defaults=False): def list_with_base_rules(
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules """Combine the list of rules set by the user with the default push rules
Args: Args:
rawrules(list): The rules the user has modified or set. rawrules: The rules the user has modified or set.
use_new_defaults(bool): Whether to use the new experimental default rules when use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules. appending or prepending default rules.
Returns: Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist return ruleslist
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): def make_base_append_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = [] rules = []
if kind == "override": if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules) rules = copy.deepcopy(rules)
for r in rules: for r in rules:
# Only modify the actions, keep the conditions the same. # Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"]) modified = modified_base_rules.get(r["rule_id"])
if modified: if modified:
r["actions"] = modified["actions"] r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules return rules
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): def make_base_prepend_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = [] rules = []
if kind == "override": if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules) rules = copy.deepcopy(rules)
for r in rules: for r in rules:
# Only modify the actions, keep the conditions the same. # Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"]) modified = modified_base_rules.get(r["rule_id"])
if modified: if modified:
r["actions"] = modified["actions"] r["actions"] = modified["actions"]

View file

@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@@ -25,18 +26,18 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
rules_by_room = {}
push_rules_invalidation_counter = Counter( push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
) )
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once. room at once.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False, resizable=False,
) )
async def _get_rules_for_event(self, event, context): async def _get_rules_for_event(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event, """This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite. as well as the push rules for the invitee if the event is an invite.
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user return rules_by_user
@lru_cache() @lru_cache()
def _get_rules_for_room(self, room_id): def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id """Get the current RulesForRoom object for the given room id
Returns:
RulesForRoom
""" """
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be # before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics, self.room_push_rule_cache_metrics,
) )
async def _get_power_levels_and_sender_level(self, event, context): async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY) pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id: if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and # fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case # not having a power level event is an extreme edge case
pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = self.auth.compute_auth_events( auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events_dict = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None: async def action_for_event_by_user(
self, event: EventBase, context: EventContext
) -> None:
"""Given an event and context, evaluate the push rules, check if the message """Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the should increment the unread count, and insert the results into the
event_push_actions_staging table. event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context) count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context) room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels event, len(room_members), sender_power_level, power_levels
) )
condition_cache = {} condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items(): for uid, rules in rules_by_user.items():
if event.sender == uid: if event.sender == uid:
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
) )
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
display_name: str,
cache: Dict[str, bool],
) -> bool:
for cond in conditions: for cond in conditions:
_id = cond.get("_id", None) _id = cond.get("_id", None)
if _id: if _id:
@@ -277,15 +286,19 @@ class RulesForRoom:
""" """
def __init__( def __init__(
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics self,
hs: "HomeServer",
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
): ):
""" """
Args: Args:
hs (HomeServer) hs: The HomeServer object.
room_id (str) room_id: The room ID.
rules_for_room_cache: The cache object that caches these rules_for_room_cache: The cache object that caches these
RoomsForUser objects. RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric) room_push_rule_cache_metrics: The metrics object
""" """
self.room_id = room_id self.room_id = room_id
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room") self.linearizer = Linearizer(name="rules_for_room")
self.member_map = {} # event_id -> (user_id, state) # event_id -> (user_id, state)
self.rules_by_user = {} # user_id -> rules self.member_map = {} # type: Dict[str, Tuple[str, str]]
# user_id -> rules
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of # The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached # a new event comes along, we know that we can just return the cached
@@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for) # calculate push for)
# These never need to be invalidated as we will never set up push for # These never need to be invalidated as we will never set up push for
# them. # them.
self.uninteresting_user_set = set() self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as # We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object, # otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback. # to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
async def get_rules(self, event, context): async def get_rules(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are """Given an event context return the rules for all users who are
currently in the room. currently in the room.
""" """
@@ -356,6 +373,8 @@ class RulesForRoom:
else: else:
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses() push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids)) push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user return ret_rules_by_user
async def _update_rules_with_member_event_ids( async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event self,
): ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for """Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list. any newly joined users in the `member_event_ids` list.
Args: Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules. updated with any new rules.
member_event_ids (dict): Dict of user id to event id for membership events member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules state_group: The state group we are currently computing push rules
for. Used when updating the cache. for. Used when updating the cache.
event: The event we are currently computing push rules for.
""" """
sequence = self.sequence sequence = self.sequence
@@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values()) logger.debug("Found members %r: %r", self.room_id, members.values())
user_ids = { joined_user_ids = {
user_id user_id
for user_id, membership in members.values() for user_id, membership in members.values()
if membership == Membership.JOIN if membership == Membership.JOIN
} }
logger.debug("Joined: %r", user_ids) logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that # Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread # room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into # counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here. # the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, user_ids)) user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules( rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group) self.update_cache(sequence, members, ret_rules_by_user, state_group)
def invalidate_all(self): def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback # Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being # as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use # GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {} self.rules_by_user = {}
push_rules_invalidation_counter.inc() push_rules_invalidation_counter.inc()
def update_cache(self, sequence, members, rules_by_user, state_group): def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence: if sequence == self.sequence:
self.member_map.update(members) self.member_map.update(members)
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
@@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache) cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str) room_id = attr.ib(type=str)
def __call__(self): def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False) rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules: if rules:
rules.invalidate_all() rules.invalidate_all()
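The GC subtlety noted in the comments above is worth spelling out; the general shape of the trick, sketched generically rather than copied from Synapse:

    import attr

    @attr.s(slots=True)
    class CacheInvalidator:
        # Hold the cache and the key rather than the cached object, so
        # this callback never pins the object in memory once the cache
        # has dropped it.
        cache = attr.ib()
        key = attr.ib()

        def __call__(self) -> None:
            obj = self.cache.get(self.key)
            if obj:
                obj.invalidate_all()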

View file

@@ -14,24 +14,27 @@
# limitations under the License. # limitations under the License.
import copy import copy
from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.types import UserID
def format_push_rules_for_user(user, ruleslist): def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries """Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules""" to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist) ruleslist = copy.deepcopy(ruleslist)
rules = {"global": {}, "device": {}} rules = {
"global": {},
"device": {},
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"]) rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist: for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r["priority_class"]) template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff. # Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules return rules
def _add_empty_priority_class_arrays(d): def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys(): for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = [] d[pc] = []
return d return d
def _rule_to_template(rule): def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None unscoped_rule_id = None
if "rule_id" in rule: if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None return None
templaterule = {"actions": rule["actions"]} templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"] templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
# with PRIORITY_CLASS_INVERSE_MAP.
raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id: if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule return templaterule
def _rule_id_from_namespaced(in_rule_id): def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1] return in_rule_id.split("/")[-1]
def _priority_class_to_template_name(pc): def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc] return PRIORITY_CLASS_INVERSE_MAP[pc]
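Concretely, within this module a client-facing rule ID is the final path segment of the stored, namespaced ID, and template names come straight from the priority-class inverse map:

    assert _rule_id_from_namespaced("global/underride/.m.rule.message") == ".m.rule.message"
    assert _priority_class_to_template_name(PRIORITY_CLASS_MAP["underride"]) == "underride"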

View file

@@ -14,11 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import RoomStreamToken from synapse.push import Pusher, PusherConfig, ThrottleParams
from synapse.push.mailer import Mailer
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +52,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False INCLUDE_ALL_UNREAD_NOTIFS = False
class EmailPusher: class EmailPusher(Pusher):
""" """
A pusher that sends email notifications about events (approximately) A pusher that sends email notifications about events (approximately)
when they happen. when they happen.
@@ -54,37 +60,30 @@ class EmailPusher:
factor out the common parts factor out the common parts
""" """
def __init__(self, hs, pusherdict, mailer): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
self.hs = hs super().__init__(hs, pusher_config)
self.mailer = mailer self.mailer = mailer
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.email = pusher_config.pushkey
self.pusher_id = pusherdict["id"] self.timed_call = None # type: Optional[DelayedCall]
self.user_id = pusherdict["user_name"] self.throttle_params = {} # type: Dict[str, ThrottleParams]
self.app_id = pusherdict["app_id"] self._inited = False
self.email = pusherdict["pushkey"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None
self.throttle_params = None
# See httppusher
self.max_stream_ordering = None
self._is_processing = False self._is_processing = False
def on_started(self, should_check_for_notifs): def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started. """Called when this pusher has been started.
Args: Args:
should_check_for_notifs (bool): Whether we should immediately should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there check for push to send. Set to False only if it's known there
is nothing to send is nothing to send
""" """
if should_check_for_notifs and self.mailer is not None: if should_check_for_notifs and self.mailer is not None:
self._start_processing() self._start_processing()
def on_stop(self): def on_stop(self) -> None:
if self.timed_call: if self.timed_call:
try: try:
self.timed_call.cancel() self.timed_call.cancel()
@@ -92,37 +91,23 @@ class EmailPusher:
pass pass
self.timed_call = None self.timed_call = None
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
if self.max_stream_ordering:
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering
)
else:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
# We could wake up and cancel the timer but there tend to be quite a # We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the # lot of read receipts so it's probably less work to just let the
# timer fire # timer fire
pass pass
def on_timer(self): def on_timer(self) -> None:
self.timed_call = None self.timed_call = None
self._start_processing() self._start_processing()
def _start_processing(self): def _start_processing(self) -> None:
if self._is_processing: if self._is_processing:
return return
run_as_background_process("emailpush.process", self._process) run_as_background_process("emailpush.process", self._process)
def _pause_processing(self): def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events. """Used by tests to temporarily pause processing of events.
Asserts that it's not currently processing. Asserts that it's not currently processing.
@@ -130,25 +115,27 @@ class EmailPusher:
assert not self._is_processing assert not self._is_processing
self._is_processing = True self._is_processing = True
def _resume_processing(self): def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing. """Used by tests to resume processing of events after pausing.
""" """
assert self._is_processing assert self._is_processing
self._is_processing = False self._is_processing = False
self._start_processing() self._start_processing()
async def _process(self): async def _process(self) -> None:
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
try: try:
self._is_processing = True self._is_processing = True
if self.throttle_params is None: if not self._inited:
# this is our first loop: load up the throttle params # this is our first loop: load up the throttle params
assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room( self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id self.pusher_id
) )
self._inited = True
# if the max ordering changes while we're running _unsafe_process, # if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up. # call it again, and so on until we've caught up.
@@ -163,17 +150,18 @@ class EmailPusher:
finally: finally:
self._is_processing = False self._is_processing = False
async def _unsafe_process(self): async def _unsafe_process(self) -> None:
""" """
Main logic of the push loop without the wrapper function that sets Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it up logging, measures and guards against multiple instances of it
being run. being run.
""" """
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
unprocessed = await fn(self.user_id, start, self.max_stream_ordering) self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None soonest_due_at = None # type: Optional[int]
if not unprocessed: if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering) await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,11 +218,9 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer self.seconds_until(soonest_due_at), self.on_timer
) )
async def save_last_stream_ordering_and_success(self, last_stream_ordering): async def save_last_stream_ordering_and_success(
if last_stream_ordering is None: self, last_stream_ordering: int
# This happens if we haven't yet processed anything ) -> None:
return
self.last_stream_ordering = last_stream_ordering self.last_stream_ordering = last_stream_ordering
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
@@ -248,28 +234,30 @@ class EmailPusher:
# lets just stop and return. # lets just stop and return.
self.on_stop() self.on_stop()
def seconds_until(self, ts_msec): def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000 secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0) return max(secs, 0)
def get_room_throttle_ms(self, room_id): def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"] return self.throttle_params[room_id].throttle_ms
else: else:
return 0 return 0
def get_room_last_sent_ts(self, room_id): def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"] return self.throttle_params[room_id].last_sent_ts
else: else:
return 0 return 0
def room_ready_to_notify_at(self, room_id): def room_ready_to_notify_at(self, room_id: str) -> int:
""" """
Determines whether throttling should prevent us from sending an email Determines whether throttling should prevent us from sending an email
for the given room for the given room
Returns: The timestamp when we are next allowed to send an email notif
for this room Returns:
The timestamp when we are next allowed to send an email notif
for this room
""" """
last_sent_ts = self.get_room_last_sent_ts(room_id) last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id) throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +265,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms may_send_at = last_sent_ts + throttle_ms
return may_send_at return may_send_at
async def sent_notif_update_throttle(self, room_id, notified_push_action): async def sent_notif_update_throttle(
self, room_id: str, notified_push_action: dict
) -> None:
# We have sent a notification, so update the throttle accordingly. # We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than # If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -307,15 +297,15 @@ class EmailPusher:
new_throttle_ms = min( new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
) )
self.throttle_params[room_id] = { self.throttle_params[room_id] = ThrottleParams(
"last_sent_ts": self.clock.time_msec(), self.clock.time_msec(), new_throttle_ms,
"throttle_ms": new_throttle_ms, )
} assert self.pusher_id is not None
await self.store.set_throttle_params( await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id] self.pusher_id, room_id, self.throttle_params[room_id]
) )
async def send_notification(self, push_actions, reason): async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id) logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail( await self.mailer.send_notification_mail(
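The throttle update above now stores a `ThrottleParams` instead of a dict; the escalation itself is geometric with a cap, roughly as below. All constants here are illustrative placeholders, not the module's real values:

    THROTTLE_START_MS = 10 * 60 * 1000      # illustrative first delay
    THROTTLE_MULTIPLIER = 6                 # illustrative growth factor
    THROTTLE_MAX_MS = 24 * 60 * 60 * 1000   # illustrative ceiling

    def next_throttle_ms(current_ms: int) -> int:
        # Each further notification in a busy room multiplies the
        # per-room delay, up to the ceiling.
        if current_ms == 0:
            return THROTTLE_START_MS
        return min(current_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS)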

View file

@@ -14,19 +14,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools from . import push_rule_evaluator, push_tools
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
http_push_processed_counter = Counter( http_push_processed_counter = Counter(
@@ -50,91 +55,76 @@ http_badges_failed_counter = Counter(
) )
class HttpPusher: class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60 MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock # This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs, pusherdict): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs super().__init__(hs, pusher_config)
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock() self.app_display_name = pusher_config.app_display_name
self.state_handler = self.hs.get_state_handler() self.device_display_name = pusher_config.device_display_name
self.user_id = pusherdict["user_name"] self.pushkey_ts = pusher_config.ts
self.app_id = pusherdict["app_id"] self.data = pusher_config.data
self.app_display_name = pusherdict["app_display_name"]
self.device_display_name = pusherdict["device_display_name"]
self.pushkey = pusherdict["pushkey"]
self.pushkey_ts = pusherdict["ts"]
self.data = pusherdict["data"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict["failing_since"] self.failing_since = pusher_config.failing_since
self.timed_call = None self.timed_call = None
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process. self.data = pusher_config.data
# When new events arrive, we'll be given a window of new events: we if self.data is None:
# should honour this rather than just looking for anything higher raise PusherConfigException("'data' key can not be null for HTTP pusher")
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None
if "data" not in pusherdict:
raise PusherConfigException("No 'data' key for HTTP pusher")
self.data = pusherdict["data"]
self.name = "%s/%s/%s" % ( self.name = "%s/%s/%s" % (
pusherdict["user_name"], pusher_config.user_name,
pusherdict["app_id"], pusher_config.app_id,
pusherdict["pushkey"], pusher_config.pushkey,
) )
if self.data is None: # Validate that there's a URL and it is of the proper form.
raise PusherConfigException("data can not be null for HTTP pusher")
if "url" not in self.data: if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher") raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
self.http_client = hs.get_proxied_http_client() url = self.data["url"]
if not isinstance(url, str):
raise PusherConfigException("'url' must be a string")
url_parts = urllib.parse.urlparse(url)
# Note that the specification also says the scheme must be HTTPS, but
# it isn't up to the homeserver to verify that.
if url_parts.path != "/_matrix/push/v1/notify":
raise PusherConfigException(
"'url' must have a path of '/_matrix/push/v1/notify'"
)
self.url = url
self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {} self.data_minus_url = {}
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url["url"] del self.data_minus_url["url"]
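For illustration, the stricter pusher validation above boils down to the following standalone sketch (standard library only; `PusherConfigException` here stands in for the real class in `synapse.push`):

    import urllib.parse

    class PusherConfigException(Exception):
        """Stand-in for synapse.push.PusherConfigException."""

    def validate_push_url(data):
        # 'data' must be present, carry a string 'url', and point at the
        # spec-mandated notify endpoint; the scheme is deliberately unchecked.
        if data is None:
            raise PusherConfigException("'data' key can not be null for HTTP pusher")
        if "url" not in data:
            raise PusherConfigException("'url' required in data for HTTP pusher")
        url = data["url"]
        if not isinstance(url, str):
            raise PusherConfigException("'url' must be a string")
        if urllib.parse.urlparse(url).path != "/_matrix/push/v1/notify":
            raise PusherConfigException(
                "'url' must have a path of '/_matrix/push/v1/notify'"
            )
        return url

    validate_push_url({"url": "https://push.example.com/_matrix/push/v1/notify"})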
def on_started(self, should_check_for_notifs): def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started. """Called when this pusher has been started.
Args: Args:
should_check_for_notifs (bool): Whether we should immediately should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there check for push to send. Set to False only if it's known there
is nothing to send is nothing to send
""" """
if should_check_for_notifs: if should_check_for_notifs:
self._start_processing() self._start_processing()
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
max_stream_ordering = max_token.stream
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering or 0
)
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here, # We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge) run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
async def _update_badge(self): async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it. # to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count( badge = await push_tools.get_badge_count(
@@ -144,10 +134,10 @@ class HttpPusher:
) )
await self._send_badge(badge) await self._send_badge(badge)
def on_timer(self): def on_timer(self) -> None:
self._start_processing() self._start_processing()
def on_stop(self): def on_stop(self) -> None:
if self.timed_call: if self.timed_call:
try: try:
self.timed_call.cancel() self.timed_call.cancel()
@@ -155,13 +145,13 @@ class HttpPusher:
pass pass
self.timed_call = None self.timed_call = None
def _start_processing(self): def _start_processing(self) -> None:
if self._is_processing: if self._is_processing:
return return
run_as_background_process("httppush.process", self._process) run_as_background_process("httppush.process", self._process)
async def _process(self): async def _process(self) -> None:
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
@@ -180,15 +170,13 @@ class HttpPusher:
finally: finally:
self._is_processing = False self._is_processing = False
async def _unsafe_process(self): async def _unsafe_process(self) -> None:
""" """
Looks for unsent notifications and dispatches them, in order Looks for unsent notifications and dispatches them, in order
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
run once per pusher. run once per pusher.
""" """
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@@ -257,17 +245,12 @@ class HttpPusher:
) )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = await self.store.update_pusher_last_stream_ordering( await self.store.update_pusher_last_stream_ordering(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
self.last_stream_ordering, self.last_stream_ordering,
) )
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
# lets just stop and return.
self.on_stop()
return
self.failing_since = None self.failing_since = None
await self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
@@ -283,7 +266,7 @@ class HttpPusher:
) )
break break
async def _process_one(self, push_action): async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]: if "notify" not in push_action["actions"]:
return True return True
@@ -314,7 +297,9 @@ class HttpPusher:
await self.hs.remove_pusher(self.app_id, pk, self.user_id) await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True return True
async def _build_notification_dict(self, event, tweaks, badge): async def _build_notification_dict(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Dict[str, Any]:
priority = "low" priority = "low"
if ( if (
event.type == EventTypes.Encrypted event.type == EventTypes.Encrypted
@@ -325,6 +310,8 @@ class HttpPusher:
# or may do so (i.e. is encrypted so has unknown effects). # or may do so (i.e. is encrypted so has unknown effects).
priority = "high" priority = "high"
# This was checked in the __init__, but mypy doesn't seem to know that.
assert self.data is not None
if self.data.get("format") == "event_id_only": if self.data.get("format") == "event_id_only":
d = { d = {
"notification": { "notification": {
@@ -344,9 +331,7 @@ class HttpPusher:
} }
return d return d
ctx = await push_tools.get_context_for_event( ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
self.storage, self.state_handler, event, self.user_id
)
d = { d = {
"notification": { "notification": {
@@ -386,7 +371,9 @@ class HttpPusher:
return d return d
async def dispatch_push(self, event, tweaks, badge): async def dispatch_push(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge) notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict: if not notification_dict:
return [] return []
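The explicit `Union[bool, Iterable[str]]` annotation documents the calling convention visible in `_process_one` above: `False` signals a failed delivery (retried with backoff), while an iterable carries pushkeys the push gateway rejected, which are then removed. A hedged sketch of a consumer, with `pusher` standing in for an `HttpPusher` instance:

    async def handle_dispatch(pusher, event, tweaks, badge):
        rejected = await pusher.dispatch_push(event, tweaks, badge)
        if rejected is False:
            return False  # gateway unreachable or errored: caller backs off
        for pk in rejected:  # pushkeys the gateway told us to drop
            await pusher.hs.remove_pusher(pusher.app_id, pk, pusher.user_id)
        return True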


@@ -19,7 +19,7 @@ import logging
import urllib.parse import urllib.parse
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach import bleach
import jinja2 import jinja2
@@ -27,16 +27,20 @@ import jinja2
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, calculate_room_name,
descriptor_from_member_events, descriptor_from_member_events,
name_from_member_event, name_from_member_event,
) )
from synapse.types import UserID from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
class Mailer: class Mailer:
def __init__(self, hs, app_name, template_html, template_text): def __init__(
self,
hs: "HomeServer",
app_name: str,
template_html: jinja2.Template,
template_text: jinja2.Template,
):
self.hs = hs self.hs = hs
self.template_html = template_html self.template_html = template_html
self.template_text = template_text self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
async def send_password_reset_mail(self, email_address, token, client_secret, sid): async def send_password_reset_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a password reset link to a user """Send an email with a password reset link to a user
Args: Args:
email_address (str): Email address we're sending the password email_address: Email address we're sending the password
reset to reset to
token (str): Unique token generated by the server to verify token: Unique token generated by the server to verify
the email was received the email was received
client_secret (str): Unique token generated by the client to client_secret: Unique token generated by the client to
group together multiple email sending attempts group together multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@@ -136,17 +148,19 @@ class Mailer:
template_vars, template_vars,
) )
async def send_registration_mail(self, email_address, token, client_secret, sid): async def send_registration_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a registration confirmation link to a user """Send an email with a registration confirmation link to a user
Args: Args:
email_address (str): Email address we're sending the registration email_address: Email address we're sending the registration
link to link to
token (str): Unique token generated by the server to verify token: Unique token generated by the server to verify
the email was received the email was received
client_secret (str): Unique token generated by the client to client_secret: Unique token generated by the client to
group together multiple email sending attempts group together multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@@ -164,18 +178,20 @@ class Mailer:
template_vars, template_vars,
) )
async def send_add_threepid_mail(self, email_address, token, client_secret, sid): async def send_add_threepid_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account """Send an email with a validation link to a user for adding a 3pid to their account
Args: Args:
email_address (str): Email address we're sending the validation link to email_address: Email address we're sending the validation link to
token (str): Unique token generated by the server to verify the email was received token: Unique token generated by the server to verify the email was received
client_secret (str): Unique token generated by the client to group together client_secret: Unique token generated by the client to group together
multiple email sending attempts multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@@ -194,8 +210,13 @@ class Mailer:
) )
async def send_notification_mail( async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason self,
): app_id: str,
user_id: str,
email_address: str,
push_actions: Iterable[Dict[str, Any]],
reason: Dict[str, Any],
) -> None:
"""Send email regarding a user's room notifications""" """Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
@@ -203,7 +224,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions] [pa["event_id"] for pa in push_actions]
) )
notifs_by_room = {} notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions: for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa) notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -262,7 +283,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars) await self.send_email(email_address, summary_text, template_vars)
async def send_email(self, email_address, subject, extra_template_vars): async def send_email(
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
) -> None:
"""Send an email with the given information and template text""" """Send an email with the given information and template text"""
try: try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name} from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +338,13 @@ class Mailer:
) )
async def get_room_vars( async def get_room_vars(
self, room_id, user_id, notifs, notif_events, room_state_ids self,
): room_id: str,
user_id: str,
notifs: Iterable[Dict[str, Any]],
notif_events: Dict[str, EventBase],
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user. # Check if one of the notifs is an invite event for the user.
is_invite = False is_invite = False
for n in notifs: for n in notifs:
@@ -334,7 +362,7 @@ class Mailer:
"notifs": [], "notifs": [],
"invite": is_invite, "invite": is_invite,
"link": self.make_room_link(room_id), "link": self.make_room_link(room_id),
} } # type: Dict[str, Any]
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
@@ -365,7 +393,13 @@ class Mailer:
return room_vars return room_vars
async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): async def get_notif_vars(
self,
notif: Dict[str, Any],
user_id: str,
notif_event: EventBase,
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
results = await self.store.get_events_around( results = await self.store.get_events_around(
notif["room_id"], notif["room_id"],
notif["event_id"], notif["event_id"],
@@ -391,7 +425,9 @@ class Mailer:
return ret return ret
async def get_message_vars(self, notif, event, room_state_ids): async def get_message_vars(
self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None return None
@@ -432,7 +468,9 @@ class Mailer:
return ret return ret
def add_text_message_vars(self, messagevars, event): def add_text_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
msgformat = event.content.get("format") msgformat = event.content.get("format")
messagevars["format"] = msgformat messagevars["format"] = msgformat
@@ -445,15 +483,22 @@ class Mailer:
elif body: elif body:
messagevars["body_text_html"] = safe_text(body) messagevars["body_text_html"] = safe_text(body)
return messagevars def add_image_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
def add_image_message_vars(self, messagevars, event): ) -> None:
messagevars["image_url"] = event.content["url"] """
Potentially add an image URL to the message variables.
return messagevars """
if "url" in event.content:
messagevars["image_url"] = event.content["url"]
async def make_summary_text( async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason self,
notifs_by_room: Dict[str, List[Dict[str, Any]]],
room_state_ids: Dict[str, StateMap[str]],
notif_events: Dict[str, EventBase],
user_id: str,
reason: Dict[str, Any],
): ):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
# Only one room has new stuff # Only one room has new stuff
@@ -580,7 +625,7 @@ class Mailer:
"app": self.app_name, "app": self.app_name,
} }
def make_room_link(self, room_id): def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url: if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector": elif self.app_name == "Vector":
@@ -590,7 +635,7 @@ class Mailer:
base_url = "https://matrix.to/#" base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id) return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif): def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url: if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % ( return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url, self.hs.config.email_riot_base_url,
@@ -606,7 +651,9 @@ class Mailer:
else: else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
def make_unsubscribe_link(self, user_id, app_id, email_address): def make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str
) -> str:
params = { params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id, "app_id": app_id,
@@ -620,7 +667,7 @@ class Mailer:
) )
def safe_markup(raw_html): def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup( return jinja2.Markup(
bleach.linkify( bleach.linkify(
bleach.clean( bleach.clean(
@@ -635,7 +682,7 @@ def safe_markup(raw_html):
) )
def safe_text(raw_text): def safe_text(raw_text: str) -> jinja2.Markup:
""" """
Process text: treat it as HTML but escape any tags (i.e. just escape the Process text: treat it as HTML but escape any tags (i.e. just escape the
HTML) then linkify it. HTML) then linkify it.
@@ -655,7 +702,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret return ret
def string_ordinal_total(s): def string_ordinal_total(s: str) -> int:
tot = 0 tot = 0
for c in s: for c in s:
tot += ord(c) tot += ord(c)
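`string_ordinal_total` simply sums a string's code points (its `return tot` falls outside this hunk); a stable number of this kind is handy for making a deterministic choice per input, for example:

    def string_ordinal_total(s: str) -> int:
        return sum(ord(c) for c in s)

    subjects = ["You have new messages", "There is activity in your rooms"]
    # the same room ID always yields the same pick (illustrative use only)
    subject = subjects[string_ordinal_total("!room:example.org") % len(subjects)]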


@@ -15,8 +15,14 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name( async def calculate_room_name(
store, store: "DataStore",
room_state_ids, room_state_ids: StateMap[str],
user_id, user_id: str,
fallback_to_members=True, fallback_to_members: bool = True,
fallback_to_single_member=True, fallback_to_single_member: bool = True,
): ) -> Optional[str]:
""" """
Works out a user-facing name for the given room as per Matrix Works out a user-facing name for the given room as per Matrix
spec recommendations. spec recommendations.
Does not yet support internationalisation. Does not yet support internationalisation.
Args: Args:
room_state: Dictionary of the room's state store: The data store to query.
room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no based on the room's members if the room has no
title or aliases. title or aliases.
fallback_to_single_member: If False, return None instead of generating a
name based on the user who invited this user to the room if the room
has no title or aliases.
Returns: Returns:
(string or None) A human readable name for the room. A human readable name for the room, if possible.
""" """
# does it have a name? # does it have a name?
if (EventTypes.Name, "") in room_state_ids: if (EventTypes.Name, "") in room_state_ids:
@@ -97,7 +107,7 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event), name_from_member_event(inviter_member_event),
) )
else: else:
return return None
else: else:
return "Room Invite" return "Room Invite"
@@ -150,19 +160,19 @@ async def calculate_room_name(
else: else:
return ALL_ALONE return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member: elif len(other_members) == 1 and not fallback_to_single_member:
return return None
else:
return descriptor_from_member_events(other_members) return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events): def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events. """Get a description of the room based on the member events.
Args: Args:
member_events (Iterable[FrozenEvent]) member_events: The events of a room.
Returns: Returns:
str The room description
""" """
member_events = list(member_events) member_events = list(member_events)
@@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
) )
def name_from_member_event(member_event): def name_from_member_event(member_event: EventBase) -> str:
if ( if (
member_event.content member_event.content
and "displayname" in member_event.content and "displayname" in member_event.content
@@ -193,12 +203,12 @@ def name_from_member_event(member_event):
return member_event.state_key return member_event.state_key
def _state_as_two_level_dict(state): def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
ret = {} ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items(): for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v ret.setdefault(k[0], {})[k[1]] = v
return ret return ret
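`_state_as_two_level_dict` regroups a flat `StateMap[str]` (keyed by `(event_type, state_key)` pairs) into nested dicts, for example:

    state = {
        ("m.room.name", ""): "$name_event",
        ("m.room.member", "@alice:example.org"): "$alice_join",
        ("m.room.member", "@bob:example.org"): "$bob_join",
    }
    _state_as_two_level_dict(state)
    # => {"m.room.name": {"": "$name_event"},
    #     "m.room.member": {"@alice:example.org": "$alice_join",
    #                       "@bob:example.org": "$bob_join"}}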
def _looks_like_an_alias(string): def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None return ALIAS_RE.match(string) is not None


@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count): def _room_member_count(
ev: EventBase, condition: Dict[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count) return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(ev, condition, sender_power_level, power_levels): def _sender_notification_permission(
ev: EventBase,
condition: Dict[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
notif_level_key = condition.get("key") notif_level_key = condition.get("key")
if notif_level_key is None: if notif_level_key is None:
return False return False
notif_levels = power_levels.get("notifications", {}) notif_levels = power_levels.get("notifications", {})
assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50) room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number): def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition: if "is" not in condition:
return False return False
m = INEQUALITY_EXPR.match(condition["is"]) m = INEQUALITY_EXPR.match(condition["is"])
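For reference, `INEQUALITY_EXPR` accepts an optional comparator followed by digits, so the `is` field of these push-rule conditions looks like (values illustrative):

    conditions = [
        {"kind": "room_member_count", "is": "2"},     # exactly two members
        {"kind": "room_member_count", "is": ">=10"},  # ten or more
        {"kind": "room_member_count", "is": "<100"},  # fewer than one hundred
    ]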
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase, event: EventBase,
room_member_count: int, room_member_count: int,
sender_power_level: int, sender_power_level: int,
power_levels: dict, power_levels: Dict[str, Union[int, Dict[str, int]]],
): ):
self._event = event self._event = event
self._room_member_count = room_member_count self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches(self, condition: dict, user_id: str, display_name: str) -> bool: def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str
) -> bool:
if condition["kind"] == "event_match": if condition["kind"] == "event_match":
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name": elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,) return r"(^|\W)%s(\W|$)" % (r,)
def _flatten_dict(d, prefix=[], result=None): def _flatten_dict(
d: Union[EventBase, dict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
if prefix is None:
prefix = []
if result is None: if result is None:
result = {} result = {}
for key, value in d.items(): for key, value in d.items():
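Replacing `prefix=[]` with an `Optional` default sidesteps Python's shared-mutable-default pitfall: a default value is evaluated once, when the `def` statement runs, and is then shared across calls. A minimal demonstration of the bug class being avoided:

    def broken(item, acc=[]):  # the [] is created once, at definition time
        acc.append(item)
        return acc

    broken("a")  # ["a"]
    broken("b")  # ["a", "b"]  <- state leaked in from the first call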


@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage import Storage
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge return badge
async def get_context_for_event(storage: Storage, state_handler, ev, user_id): async def get_context_for_event(
storage: Storage, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {} ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)


@@ -14,25 +14,31 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional
from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
from .httppusher import HttpPusher if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherFactory: class PusherFactory:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.config = hs.config self.config = hs.config
self.pusher_types = {"http": HttpPusher} self.pusher_types = {
"http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text self._notif_template_text = hs.config.email_notif_template_text
@@ -41,16 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type") logger.info("defined email pusher type")
def create_pusher(self, pusherdict): def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
kind = pusherdict["kind"] kind = pusher_config.kind
f = self.pusher_types.get(kind, None) f = self.pusher_types.get(kind, None)
if not f: if not f:
return None return None
logger.debug("creating %s pusher for %r", kind, pusherdict) logger.debug("creating %s pusher for %r", kind, pusher_config)
return f(self.hs, pusherdict) return f(self.hs, pusher_config)
def _create_email_pusher(self, _hs, pusherdict): def _create_email_pusher(
app_name = self._app_name_from_pusherdict(pusherdict) self, _hs: "HomeServer", pusher_config: PusherConfig
) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name) mailer = self.mailers.get(app_name)
if not mailer: if not mailer:
mailer = Mailer( mailer = Mailer(
@@ -60,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text, template_text=self._notif_template_text,
) )
self.mailers[app_name] = mailer self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer) return EmailPusher(self.hs, pusher_config, mailer)
def _app_name_from_pusherdict(self, pusherdict): def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
data = pusherdict["data"] data = pusher_config.data
if isinstance(data, dict): if isinstance(data, dict):
brand = data.get("brand") brand = data.get("brand")


@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Union from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge from prometheus_client import Gauge
@@ -23,11 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.push import PusherConfigException from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -77,9 +75,9 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher # map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self): def start(self) -> None:
"""Starts the pushers off in a background process. """Starts the pushers off in a background process.
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
@@ -89,52 +87,53 @@ class PusherPool:
async def add_pusher( async def add_pusher(
self, self,
user_id, user_id: str,
access_token, access_token: Optional[int],
kind, kind: str,
app_id, app_id: str,
app_display_name, app_display_name: str,
device_display_name, device_display_name: str,
pushkey, pushkey: str,
lang, lang: Optional[str],
data, data: JsonDict,
profile_tag="", profile_tag: str = "",
): ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
Returns: Returns:
EmailPusher|HttpPusher The newly created pusher.
""" """
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering, so it will process pushes from this point onwards.
last_stream_ordering = self.store.get_room_max_stream_ordering()
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
# code path adding pushers. # code path adding pushers.
self.pusher_factory.create_pusher( self.pusher_factory.create_pusher(
{ PusherConfig(
"id": None, id=None,
"user_name": user_id, user_name=user_id,
"kind": kind, access_token=access_token,
"app_id": app_id, profile_tag=profile_tag,
"app_display_name": app_display_name, kind=kind,
"device_display_name": device_display_name, app_id=app_id,
"pushkey": pushkey, app_display_name=app_display_name,
"ts": time_now_msec, device_display_name=device_display_name,
"lang": lang, pushkey=pushkey,
"data": data, ts=time_now_msec,
"last_stream_ordering": None, lang=lang,
"last_success": None, data=data,
"failing_since": None, last_stream_ordering=last_stream_ordering,
} last_success=None,
failing_since=None,
)
) )
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
await self.store.add_pusher( await self.store.add_pusher(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
@@ -154,43 +153,44 @@ class PusherPool:
return pusher return pusher
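The dict-to-object migration above revolves around a `PusherConfig` container; here is a hedged sketch of its shape, with field names read off the call site above (the attrs-based definition and the exact types are assumptions):

    from typing import Any, Optional

    import attr

    @attr.s(slots=True)
    class PusherConfig:
        """Parameters of a pusher, as pulled from (or destined for) the DB."""

        id = attr.ib(type=Optional[int])
        user_name = attr.ib(type=str)
        access_token = attr.ib(type=Optional[int])
        profile_tag = attr.ib(type=str)
        kind = attr.ib(type=str)
        app_id = attr.ib(type=str)
        app_display_name = attr.ib(type=str)
        device_display_name = attr.ib(type=str)
        pushkey = attr.ib(type=str)
        ts = attr.ib(type=int)
        lang = attr.ib(type=Optional[str])
        data = attr.ib(type=Optional[Any])
        last_stream_ordering = attr.ib(type=int)
        last_success = attr.ib(type=Optional[int])
        failing_since = attr.ib(type=Optional[int])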
async def remove_pushers_by_app_id_and_pushkey_not_user( async def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id self, app_id: str, pushkey: str, not_user_id: str
): ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove: for p in to_remove:
if p["user_name"] != not_user_id: if p.user_name != not_user_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
app_id, app_id,
pushkey, pushkey,
p["user_name"], p.user_name,
) )
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
async def remove_pushers_by_access_token(self, user_id, access_tokens): async def remove_pushers_by_access_token(
self, user_id: str, access_tokens: Iterable[int]
) -> None:
"""Remove the pushers for a given user corresponding to a set of """Remove the pushers for a given user corresponding to a set of
access_tokens. access_tokens.
Args: Args:
user_id (str): user to remove pushers for user_id: user to remove pushers for
access_tokens (Iterable[int]): access token *ids* to remove pushers access_tokens: access token *ids* to remove pushers for
for
""" """
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id): for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens: if p.access_token in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p["app_id"], p.app_id,
p["pushkey"], p.pushkey,
p["user_name"], p.user_name,
) )
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@@ -209,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token) self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications") @wrap_as_background_process("on_new_notifications")
async def _on_new_notifications(self, max_token: RoomStreamToken): async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock # We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector # component. This is safe to do as long as we *always* ignore the vector
# clock components. # clock components.
@@ -239,7 +239,9 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): async def on_new_receipts(
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@@ -267,28 +269,30 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
async def start_pusher_by_id(self, app_id, pushkey, user_id): async def start_pusher_by_id(
self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it """Look up the details for the given pusher, and start it
Returns: Returns:
EmailPusher|HttpPusher|None: The pusher started, if any The pusher started, if any
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
return return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None pusher_config = None
for r in resultlist: for r in resultlist:
if r["user_name"] == user_id: if r.user_name == user_id:
pusher_dict = r pusher_config = r
pusher = None pusher = None
if pusher_dict: if pusher_config:
pusher = await self._start_pusher(pusher_dict) pusher = await self._start_pusher(pusher_config)
return pusher return pusher
@@ -303,44 +307,44 @@ class PusherPool:
logger.info("Started pushers") logger.info("Started pushers")
async def _start_pusher(self, pusherdict): async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher """Start the given pusher
Args: Args:
pusherdict (dict): dict with the values pulled from the db table pusher_config: The pusher configuration with the values pulled from the db table
Returns: Returns:
EmailPusher|HttpPusher The newly created pusher or None.
""" """
if not self._pusher_shard_config.should_handle( if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"] self._instance_name, pusher_config.user_name
): ):
return return None
try: try:
p = self.pusher_factory.create_pusher(pusherdict) p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e: except PusherConfigException as e:
logger.warning( logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
pusherdict["id"], pusher_config.id,
pusherdict.get("user_name"), pusher_config.user_name,
pusherdict.get("app_id"), pusher_config.app_id,
pusherdict.get("pushkey"), pusher_config.pushkey,
e, e,
) )
return return None
except Exception: except Exception:
logger.exception( logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"], "Couldn't start pusher id %i: caught Exception", pusher_config.id,
) )
return return None
if not p: if not p:
return return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
byuser = self.pushers.setdefault(pusherdict["user_name"], {}) byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser: if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop() byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p byuser[appid_pushkey] = p
@@ -350,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a # Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to # lot cheaper to do than actually fetching the exact rows we need to
# push. # push.
user_id = pusherdict["user_name"] user_id = pusher_config.user_name
last_stream_ordering = pusherdict["last_stream_ordering"] last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering: if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user( have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering user_id, last_stream_ordering
@@ -365,7 +369,7 @@ class PusherPool:
return p return p
async def remove_pusher(self, app_id, pushkey, user_id): async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey) appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {}) byuser = self.pushers.get(user_id, {})


@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET") assert self.METHOD in ("PUT", "POST", "GET")
self._replication_secret = None
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
received_secret = parts[1].decode("ascii")
if self._replication_secret == received_secret:
# Success!
return
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod @abc.abstractmethod
async def _serialize_payload(**kwargs): async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request. """Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
replication_secret = None
if hs.config.worker.worker_replication_secret:
replication_secret = hs.config.worker.worker_replication_secret.encode(
"ascii"
)
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress() @outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs): async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not. # the master, and so whether we should clean up or not.
while True: while True:
headers = {} # type: Dict[bytes, List[bytes]] headers = {} # type: Dict[bytes, List[bytes]]
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False) inject_active_span_byte_dict(headers, None, check_destination=False)
try: try:
result = await request_func(uri, data, headers=headers) result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
""" """
url_args = list(self.PATH_ARGS) url_args = list(self.PATH_ARGS)
handler = self._handle_request
method = self.METHOD method = self.METHOD
if self.CACHE: if self.CACHE:
handler = self._cached_handler # type: ignore
url_args.append("txn_id") url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths( http_server.register_paths(
method, [pattern], handler, self.__class__.__name__, method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
) )
def _cached_handler(self, request, txn_id, **kwargs): def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks """Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that, if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response. otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the # We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well. # other PATH_ARGS as well.
assert self.CACHE # Check the authorization headers before handling the request.
if self._replication_secret:
self._check_auth(request)
return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) if self.CACHE:
txn_id = kwargs.pop("txn_id")
return self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
return self._handle_request(request, **kwargs)


@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): async def _serialize_payload(
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
):
""" """
Args: Args:
device_id (str|None): Device ID to use, if None a new one is device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id, "device_id": device_id,
"initial_display_name": initial_display_name, "initial_display_name": initial_display_name,
"is_guest": is_guest, "is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
} }
async def _handle_request(self, request, user_id): async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"] device_id = content["device_id"]
initial_display_name = content["initial_display_name"] initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"] is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost=is_appservice_ghost,
) )
return 200, {"device_id": device_id, "access_token": access_token} return 200, {"device_id": device_id, "access_token": access_token}
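For reference, the JSON body this servlet now round-trips looks like the following (values illustrative; the user ID travels in the URL path rather than the body):

    payload = {
        "device_id": None,  # None asks the registration handler to generate one
        "initial_display_name": "Alice's phone",
        "is_guest": False,
        "is_appservice_ghost": False,  # the field added by this change
    }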
