Merge remote-tracking branch 'upstream/release-v1.29.0'

This commit is contained in:
Tulir Asokan 2021-03-04 12:49:37 +02:00
commit adb990d8ba
126 changed files with 2399 additions and 765 deletions

.gitignore vendored
View file

@ -6,13 +6,14 @@
*.egg *.egg
*.egg-info *.egg-info
*.lock *.lock
*.pyc *.py[cod]
*.snap *.snap
*.tac *.tac
_trial_temp/ _trial_temp/
_trial_temp*/ _trial_temp*/
/out /out
.DS_Store .DS_Store
__pycache__/
# stuff that is likely to exist when you run a server locally # stuff that is likely to exist when you run a server locally
/*.db /*.db

View file

@ -1,3 +1,60 @@
Synapse 1.29.0rc1 (2021-03-04)
==============================
Note that Synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
Features
--------
- Add rate limiters to cross-user key sharing requests. ([\#8957](https://github.com/matrix-org/synapse/issues/8957))
- Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel. ([\#8978](https://github.com/matrix-org/synapse/issues/8978))
- Add some configuration settings to make users' profile data more private. ([\#9203](https://github.com/matrix-org/synapse/issues/9203))
- The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. ([\#9372](https://github.com/matrix-org/synapse/issues/9372))
- Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. ([\#9383](https://github.com/matrix-org/synapse/issues/9383), [\#9385](https://github.com/matrix-org/synapse/issues/9385))
- Add support for regenerating thumbnails if they have been deleted but the original image is still stored. ([\#9438](https://github.com/matrix-org/synapse/issues/9438))
- Add support for `X-Forwarded-Proto` header when using a reverse proxy. ([\#9472](https://github.com/matrix-org/synapse/issues/9472), [\#9501](https://github.com/matrix-org/synapse/issues/9501), [\#9512](https://github.com/matrix-org/synapse/issues/9512), [\#9539](https://github.com/matrix-org/synapse/issues/9539))
Bugfixes
--------
- Fix a bug where users' pushers were not all deleted when they deactivated their account. ([\#9285](https://github.com/matrix-org/synapse/issues/9285), [\#9516](https://github.com/matrix-org/synapse/issues/9516))
- Fix a bug where a lot of unnecessary presence updates were sent when joining a room. ([\#9402](https://github.com/matrix-org/synapse/issues/9402))
- Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. ([\#9416](https://github.com/matrix-org/synapse/issues/9416))
- Fix a bug in single sign-on which could cause a "No session cookie found" error. ([\#9436](https://github.com/matrix-org/synapse/issues/9436))
- Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. ([\#9440](https://github.com/matrix-org/synapse/issues/9440))
- Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. ([\#9449](https://github.com/matrix-org/synapse/issues/9449))
- Fix deleting pushers when using sharded pushers. ([\#9465](https://github.com/matrix-org/synapse/issues/9465), [\#9466](https://github.com/matrix-org/synapse/issues/9466), [\#9479](https://github.com/matrix-org/synapse/issues/9479), [\#9536](https://github.com/matrix-org/synapse/issues/9536))
- Fix missing startup checks for the consistency of certain PostgreSQL sequences. ([\#9470](https://github.com/matrix-org/synapse/issues/9470))
- Fix a long-standing bug where the media repository could leak file descriptors while previewing media. ([\#9497](https://github.com/matrix-org/synapse/issues/9497))
- Properly purge the event chain cover index when purging history. ([\#9498](https://github.com/matrix-org/synapse/issues/9498))
- Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. ([\#9503](https://github.com/matrix-org/synapse/issues/9503))
- Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. ([\#9506](https://github.com/matrix-org/synapse/issues/9506))
- Prevent presence background jobs from running when presence is disabled. ([\#9530](https://github.com/matrix-org/synapse/issues/9530))
- Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. ([\#9537](https://github.com/matrix-org/synapse/issues/9537))
Improved Documentation
----------------------
- Update the example systemd config to propagate reloads to individual units. ([\#9463](https://github.com/matrix-org/synapse/issues/9463))
Internal Changes
----------------
- Add documentation and type hints to `parse_duration`. ([\#9432](https://github.com/matrix-org/synapse/issues/9432))
- Remove vestiges of `uploads_path` configuration setting. ([\#9462](https://github.com/matrix-org/synapse/issues/9462))
- Add a comment about systemd-python. ([\#9464](https://github.com/matrix-org/synapse/issues/9464))
- Test that we require validated email for email pushers. ([\#9496](https://github.com/matrix-org/synapse/issues/9496))
- Allow python to generate bytecode for synapse. ([\#9502](https://github.com/matrix-org/synapse/issues/9502))
- Fix incorrect type hints. ([\#9515](https://github.com/matrix-org/synapse/issues/9515), [\#9518](https://github.com/matrix-org/synapse/issues/9518))
- Add type hints to device and event report admin API. ([\#9519](https://github.com/matrix-org/synapse/issues/9519))
- Add type hints to user admin API. ([\#9521](https://github.com/matrix-org/synapse/issues/9521))
- Bump the versions of mypy and mypy-zope used for static type checking. ([\#9529](https://github.com/matrix-org/synapse/issues/9529))
Synapse 1.28.0 (2021-02-25) Synapse 1.28.0 (2021-02-25)
=========================== ===========================

View file

@ -85,6 +85,26 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.29.0
====================
Requirement for X-Forwarded-Proto header
----------------------------------------
When using Synapse with a reverse proxy (in particular, when using the
`x_forwarded` option on an HTTP listener), Synapse now expects to receive an
`X-Forwarded-Proto` header on incoming HTTP requests. If it is not set, Synapse
will log a warning on each received request.
To avoid the warning, administrators using a reverse proxy should ensure that
the reverse proxy sets the `X-Forwarded-Proto` header to `https` or `http` to
indicate the protocol used by the client. See the `reverse proxy documentation
<docs/reverse_proxy.md>`_, where the example configurations have been updated to
show how to set this header.
(Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
sets `X-Forwarded-Proto` by default.)
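As a rough illustration of what the proxy is expected to provide, the sketch below shows how a server behind a reverse proxy can recover the client's original protocol from this header. This is a hedged, generic sketch, not Synapse's actual implementation.

```python
# Illustrative sketch only, not Synapse's implementation: recover the
# client's original protocol from the X-Forwarded-Proto header set by
# the reverse proxy.
def original_scheme(headers: dict, fallback: str = "http") -> str:
    value = headers.get("X-Forwarded-Proto")
    if value is None:
        # This is the case in which Synapse logs a warning.
        return fallback
    # If several proxies appended values, the first is the client's.
    return value.split(",")[0].strip().lower()

assert original_scheme({"X-Forwarded-Proto": "https"}) == "https"
assert original_scheme({}) == "http"
```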
Upgrading to v1.27.0 Upgrading to v1.27.0
==================== ====================

View file

@ -58,10 +58,10 @@ trap "rm -r $tmpdir" EXIT
cp -r tests "$tmpdir" cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \ PYTHONPATH="$tmpdir" \
"${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
# build the config file # build the config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \ "${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
--config-dir="/etc/matrix-synapse" \ --config-dir="/etc/matrix-synapse" \
--data-dir="/var/lib/matrix-synapse" | --data-dir="/var/lib/matrix-synapse" |
perl -pe ' perl -pe '
@ -87,7 +87,7 @@ PYTHONPATH="$tmpdir" \
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml" ' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
# build the log config file # build the log config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \ "${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_log_config" \
--output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml" --output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml"
# add a dependency on the right version of python to substvars. # add a dependency on the right version of python to substvars.

debian/changelog vendored
View file

@ -1,3 +1,10 @@
matrix-synapse-py3 (1.29.0) UNRELEASED; urgency=medium
[ Jonathan de Jong ]
* Remove the python -B flag (don't generate bytecode) in scripts and documentation.
-- Synapse Packaging team <packages@matrix.org> Fri, 26 Feb 2021 14:41:31 +0100
matrix-synapse-py3 (1.28.0) stable; urgency=medium matrix-synapse-py3 (1.28.0) stable; urgency=medium
* New synapse release 1.28.0. * New synapse release 1.28.0.

debian/synctl.1 vendored
View file

@ -44,7 +44,7 @@ Configuration file may be generated as follows:
. .
.nf .nf
$ python \-B \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name> $ python \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
. .
.fi .fi
. .

debian/synctl.ronn vendored
View file

@ -41,7 +41,7 @@ process.
Configuration file may be generated as follows: Configuration file may be generated as follows:
$ python -B -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name> $ python -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
## ENVIRONMENT ## ENVIRONMENT

View file

@ -11,7 +11,6 @@ The image also does *not* provide a TURN server.
By default, the image expects a single volume, located at ``/data``, that will hold: By default, the image expects a single volume, located at ``/data``, that will hold:
* configuration files; * configuration files;
* temporary files during uploads;
* uploaded media and thumbnails; * uploaded media and thumbnails;
* the SQLite database if you do not configure postgres; * the SQLite database if you do not configure postgres;
* the appservices configuration. * the appservices configuration.

View file

@ -89,7 +89,6 @@ federation_rc_concurrent: 3
## Files ## ## Files ##
media_store_path: "/data/media" media_store_path: "/data/media"
uploads_path: "/data/uploads"
max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}" max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
max_image_pixels: "32M" max_image_pixels: "32M"
dynamic_thumbnails: false dynamic_thumbnails: false

View file

@ -379,11 +379,12 @@ The following fields are returned in the JSON response body:
- ``total`` - Number of rooms. - ``total`` - Number of rooms.
List media of an user List media of a user
================================ ====================
Gets a list of all local media that a specific ``user_id`` has created. Gets a list of all local media that a specific ``user_id`` has created.
The response is ordered by creation date descending and media ID descending. By default, the response is ordered by descending creation date and ascending media ID.
The newest media is on top. The newest media is on top. You can change the order with parameters
``order_by`` and ``dir``.
The API is:: The API is::
@ -440,6 +441,35 @@ The following parameters should be set in the URL:
denoting the offset in the returned results. This should be treated as an opaque value and denoting the offset in the returned results. This should be treated as an opaque value and
not explicitly set to anything other than the return value of ``next_token`` from a previous call. not explicitly set to anything other than the return value of ``next_token`` from a previous call.
Defaults to ``0``. Defaults to ``0``.
- ``order_by`` - The method by which to sort the returned list of media.
If the ordered field has duplicates, the second order is always by ascending ``media_id``,
which guarantees a stable ordering. Valid values are:
- ``media_id`` - Media are ordered alphabetically by ``media_id``.
- ``upload_name`` - Media are ordered alphabetically by the name the media was uploaded with.
- ``created_ts`` - Media are ordered by when the content was uploaded in ms.
Smallest to largest. This is the default.
- ``last_access_ts`` - Media are ordered by when the content was last accessed in ms.
Smallest to largest.
- ``media_length`` - Media are ordered by length of the media in bytes.
Smallest to largest.
- ``media_type`` - Media are ordered alphabetically by MIME-type.
- ``quarantined_by`` - Media are ordered alphabetically by the user ID that
initiated the quarantine request for this media.
- ``safe_from_quarantine`` - Media are ordered by whether the media is marked as safe
from quarantine.
- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards.
Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top
(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``).
Caution: the database only has indexes on the ``media_id``,
``user_id`` and ``created_ts`` columns. This means that if a different sort order is used
(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``,
``quarantined_by`` or ``safe_from_quarantine``), this can place a large load on the
database, especially in large environments.
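As a usage sketch (not part of the Synapse docs themselves), the parameters above could be combined as follows; the server URL, user ID and access token are placeholders.

```python
# Hypothetical client for the endpoint documented above; the URL, user ID
# and admin access token are placeholders.
import requests

resp = requests.get(
    "https://matrix.example.com/_synapse/admin/v1/users/@alice:example.com/media",
    headers={"Authorization": "Bearer <admin_access_token>"},
    # order_by=created_ts with dir=b reproduces the default
    # "newest media on top" ordering described above.
    params={"order_by": "created_ts", "dir": "b", "limit": 10},
)
resp.raise_for_status()
for item in resp.json()["media"]:
    print(item["media_id"], item.get("upload_name"))
```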
**Response** **Response**

View file

@ -9,23 +9,23 @@ of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root (443) to Matrix clients without needing to run Synapse with root
privileges. privileges.
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise` You should configure your reverse proxy to forward requests to `/_matrix` or
the requested URI in any way (for example, by decoding `%xx` escapes). `/_synapse/client` to Synapse, and have it set the `X-Forwarded-For` and
Beware that Apache *will* canonicalise URIs unless you specify `X-Forwarded-Proto` request headers.
`nocanon`.
When setting up a reverse proxy, remember that Matrix clients and other You should remember that Matrix clients and other Matrix servers do not
Matrix servers do not necessarily need to connect to your server via the necessarily need to connect to your server via the same server name or
same server name or port. Indeed, clients will use port 443 by default, port. Indeed, clients will use port 443 by default, whereas servers default to
whereas servers default to port 8448. Where these are different, we port 8448. Where these are different, we refer to the 'client port' and the
refer to the 'client port' and the 'federation port'. See [the Matrix 'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names) specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation. [delegate.md](<delegate.md>) for instructions on setting up delegation.
Endpoints that are part of the standardised Matrix specification are **NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
located under `/_matrix`, whereas endpoints specific to Synapse are the requested URI in any way (for example, by decoding `%xx` escapes).
located under `/_synapse/client`. Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
Let's assume that we expect clients to connect to our server at Let's assume that we expect clients to connect to our server at
`https://matrix.example.com`, and other servers to connect at `https://matrix.example.com`, and other servers to connect at
@ -52,6 +52,9 @@ server {
location ~* ^(\/_matrix|\/_synapse\/client) { location ~* ^(\/_matrix|\/_synapse\/client) {
proxy_pass http://localhost:8008; proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr; proxy_set_header X-Forwarded-For $remote_addr;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header Host $host;
# Nginx by default only allows file uploads up to 1M in size # Nginx by default only allows file uploads up to 1M in size
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
client_max_body_size 50M; client_max_body_size 50M;
@ -102,6 +105,7 @@ example.com:8448 {
SSLEngine on SSLEngine on
ServerName matrix.example.com; ServerName matrix.example.com;
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
@ -113,6 +117,7 @@ example.com:8448 {
SSLEngine on SSLEngine on
ServerName example.com; ServerName example.com;
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
@ -134,6 +139,9 @@ example.com:8448 {
``` ```
frontend https frontend https
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1 bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src]
# Matrix client traffic # Matrix client traffic
acl matrix-host hdr(host) -i matrix.example.com acl matrix-host hdr(host) -i matrix.example.com
@ -144,6 +152,10 @@ frontend https
frontend matrix-federation frontend matrix-federation
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1 bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src]
default_backend matrix default_backend matrix
backend matrix backend matrix

View file

@ -101,6 +101,14 @@ pid_file: DATADIR/homeserver.pid
# #
#limit_profile_requests_to_users_who_share_rooms: true #limit_profile_requests_to_users_who_share_rooms: true
# Uncomment to prevent a user's profile data from being retrieved and
# displayed in a room until they have joined it. By default, a user's
# profile data is included in an invite event, regardless of the values
# of the above two settings, and whether or not the users share a server.
# Defaults to 'true'.
#
#include_profile_data_on_invite: false
# If set to 'true', removes the need for authentication to access the server's # If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can # public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'. # query the room directory. Defaults to 'false'.
@ -699,6 +707,12 @@ acme:
# - matrix.org # - matrix.org
# - example.com # - example.com
# Uncomment to disable profile lookup over federation. By default, the
# Federation API allows other homeservers to obtain profile data of any user
# on this homeserver. Defaults to 'true'.
#
#allow_profile_lookup_over_federation: false
## Caching ## ## Caching ##
@ -2530,19 +2544,35 @@ spam_checker:
# User Directory configuration # User Directory configuration
# #
# 'enabled' defines whether users can search the user directory. If user_directory:
# false then empty responses are returned to all queries. Defaults to # Defines whether users can search the user directory. If false then
# true. # empty responses are returned to all queries. Defaults to true.
# #
# 'search_all_users' defines whether to search all users visible to your HS # Uncomment to disable the user directory.
# when searching the user directory, rather than limiting to users visible #
# in public rooms. Defaults to false. If you set it True, you'll have to #enabled: false
# rebuild the user_directory search indexes, see
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # Defines whether to search all users visible to your HS when searching
# # the user directory, rather than limiting to users visible in public
#user_directory: # rooms. Defaults to false.
# enabled: true #
# search_all_users: false # If you set it true, you'll have to rebuild the user_directory search
# indexes, see:
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
#
#search_all_users: true
# Defines whether to prefer local users in search query results.
# If True, local users are more likely to appear above remote users
# when searching the user directory. Defaults to false.
#
# Uncomment to prefer local over remote users in user directory search
# results.
#
#prefer_local_users: true
# User Consent configuration # User Consent configuration

View file

@ -25,7 +25,7 @@ well as some specific methods:
* `check_username_for_spam` * `check_username_for_spam`
* `check_registration_for_spam` * `check_registration_for_spam`
The details of the each of these methods (as well as their inputs and outputs) The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class. are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to The `ModuleApi` class provides a way for the custom spam checker class to

View file

@ -4,6 +4,7 @@ AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
# This service should be restarted when the synapse target is restarted. # This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target PartOf=matrix-synapse.target
ReloadPropagatedFrom=matrix-synapse.target
# if this is started at the same time as the main, let the main process start # if this is started at the same time as the main, let the main process start
# first, to initialise the database schema. # first, to initialise the database schema.

View file

@ -3,6 +3,7 @@ Description=Synapse master
# This service should be restarted when the synapse target is restarted. # This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target PartOf=matrix-synapse.target
ReloadPropagatedFrom=matrix-synapse.target
[Service] [Service]
Type=notify Type=notify

View file

@ -220,10 +220,6 @@ Asks the server for the current position of all streams.
Acknowledge receipt of some federation data Acknowledge receipt of some federation data
#### REMOVE_PUSHER (C)
Inform the server a pusher should be removed
### REMOTE_SERVER_UP (S, C) ### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online. Inform other processes that a remote server may have come back online.

View file

@ -22,7 +22,7 @@ import logging
import sys import sys
import time import time
import traceback import traceback
from typing import Dict, Optional, Set from typing import Dict, Iterable, Optional, Set
import yaml import yaml
@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import ( from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore, MediaRepositoryBackgroundUpdateStore,
) )
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.registration import ( from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore, RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart, find_max_generated_user_id_localpart,
@ -177,6 +178,7 @@ class Store(
UserDirectoryBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore,
EndToEndKeyBackgroundStore, EndToEndKeyBackgroundStore,
StatsStore, StatsStore,
PusherWorkerStore,
): ):
def execute(self, f, *args, **kwargs): def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
@ -629,7 +631,13 @@ class Porter(object):
await self._setup_state_group_id_seq() await self._setup_state_group_id_seq()
await self._setup_user_id_seq() await self._setup_user_id_seq()
await self._setup_events_stream_seqs() await self._setup_events_stream_seqs()
await self._setup_device_inbox_seq() await self._setup_sequence(
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
)
await self._setup_sequence(
"account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data"))
await self._setup_sequence("receipts_sequence", ("receipts_linearized", ))
await self._setup_auth_chain_sequence()
# Step 3. Get tables. # Step 3. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
@ -854,7 +862,7 @@ class Porter(object):
return done, remaining + done return done, remaining + done
async def _setup_state_group_id_seq(self): async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
) )
@ -868,7 +876,7 @@ class Porter(object):
await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
async def _setup_user_id_seq(self): async def _setup_user_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.runInteraction( curr_id = await self.sqlite_store.db_pool.runInteraction(
"setup_user_id_seq", find_max_generated_user_id_localpart "setup_user_id_seq", find_max_generated_user_id_localpart
) )
@ -877,9 +885,9 @@ class Porter(object):
next_id = curr_id + 1 next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
async def _setup_events_stream_seqs(self): async def _setup_events_stream_seqs(self) -> None:
"""Set the event stream sequences to the correct values. """Set the event stream sequences to the correct values.
""" """
@ -908,35 +916,46 @@ class Porter(object):
(curr_backward_id + 1,), (curr_backward_id + 1,),
) )
return await self.postgres_store.db_pool.runInteraction( await self.postgres_store.db_pool.runInteraction(
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
) )
async def _setup_device_inbox_seq(self): async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
"""Set the device inbox sequence to the correct value. """Set a sequence to the correct value.
""" """
curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( current_stream_ids = []
table="device_inbox", for stream_id_table in stream_id_tables:
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table=stream_id_table,
keyvalues={}, keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)", retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True, allow_none=True,
) )
current_stream_ids.append(max_stream_id)
curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( next_id = max(current_stream_ids) + 1
table="device_federation_outbox",
keyvalues={}, def r(txn):
retcol="COALESCE(MAX(stream_id), 1)", sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
allow_none=True, txn.execute(sql + " %s", (next_id, ))
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
) )
next_id = max(curr_local_id, curr_federation_id) + 1
def r(txn): def r(txn):
txn.execute( txn.execute(
"ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id,),
)
await self.postgres_store.db_pool.runInteraction(
"_setup_event_auth_chain_id", r,
) )
return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
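Because the side-by-side hunk above is hard to read, here is a plain restatement of the new generic sequence-setup logic, simplified from the diff; the `sqlite_store`/`postgres_store` objects and their `db_pool` helpers are the Porter's stores as used elsewhere in `synapse_port_db`.

```python
# Simplified restatement of the _setup_sequence logic added above; not
# the exact Porter code.
from typing import Iterable

async def setup_sequence(sqlite_store, postgres_store,
                         sequence_name: str,
                         stream_id_tables: Iterable[str]) -> None:
    # Find the highest stream_id across every table fed by this sequence.
    current_stream_ids = []
    for table in stream_id_tables:
        max_stream_id = await sqlite_store.db_pool.simple_select_one_onecol(
            table=table, keyvalues={},
            retcol="COALESCE(MAX(stream_id), 1)", allow_none=True,
        )
        current_stream_ids.append(max_stream_id)

    # Restart the Postgres sequence just past the largest existing id.
    next_id = max(current_stream_ids) + 1

    def r(txn):
        # A sequence name cannot be a bound parameter, hence the string
        # formatting; next_id is still passed as a query parameter.
        txn.execute(
            "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,) + " %s",
            (next_id,),
        )

    await postgres_store.db_pool.runInteraction(
        "_setup_%s" % (sequence_name,), r
    )
```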
############################################## ##############################################

View file

@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8", "flake8",
] ]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"] CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
# Dependencies which are exclusively required by unit test code. This is # Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests. # NOT a list of all modules that are necessary to run the unit tests.

View file

@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.28.0" __version__ = "1.29.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@ -98,11 +98,14 @@ class EventTypes:
Retention = "m.room.retention" Retention = "m.room.retention"
Presence = "m.presence"
Dummy = "org.matrix.dummy_event" Dummy = "org.matrix.dummy_event"
class EduTypes:
Presence = "m.presence"
RoomKeyRequest = "m.room_key_request"
class RejectedReason: class RejectedReason:
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional, Tuple from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.types import Requester from synapse.types import Requester
@ -42,7 +42,9 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time # * How many times an action has occurred since a point in time
# * The point in time # * The point in time
# * The rate_hz of this particular entry. This can vary per request # * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] self.actions = (
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
def can_requester_do_action( def can_requester_do_action(
self, self,
@ -82,7 +84,7 @@ class Ratelimiter:
def can_do_action( def can_do_action(
self, self,
key: Any, key: Hashable,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
@ -175,7 +177,7 @@ class Ratelimiter:
def ratelimit( def ratelimit(
self, self,
key: Any, key: Hashable,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
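The `actions` map above keeps, per key, the action count, a point in time, and the per-entry rate. A minimal standalone sketch of that bookkeeping follows; it is not Synapse's `Ratelimiter`, which additionally supports per-call rate overrides and `update=False` dry runs.

```python
# Minimal standalone sketch of the bookkeeping described above; not
# Synapse's class.
import time
from collections import OrderedDict
from typing import Hashable, Tuple

class TinyRatelimiter:
    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        # key -> (action_count, window_start_time, rate_hz)
        self.actions = OrderedDict()  # type: OrderedDict[Hashable, Tuple[float, float, float]]

    def can_do_action(self, key: Hashable) -> bool:
        now = time.monotonic()
        count, t0, rate_hz = self.actions.get(key, (0.0, now, self.rate_hz))
        # Actions "drain" over time at rate_hz; compute what is outstanding.
        outstanding = max(0.0, count - (now - t0) * rate_hz)
        if outstanding + 1 > self.burst_count:
            return False  # over budget: caller should reject or delay
        self.actions[key] = (outstanding + 1, now, rate_hz)
        return True

limiter = TinyRatelimiter(rate_hz=0.1, burst_count=3)
assert all(limiter.can_do_action("@alice:example.com") for _ in range(3))
assert not limiter.can_do_action("@alice:example.com")  # burst exhausted
```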

View file

@ -17,8 +17,6 @@ import sys
from synapse import python_dependencies # noqa: E402 from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:

View file

@ -210,7 +210,9 @@ def start(config_options):
config.update_user_directory = False config.update_user_directory = False
config.run_background_tasks = False config.run_background_tasks = False
config.start_pushers = False config.start_pushers = False
config.pusher_shard_config.instances = []
config.send_federation = False config.send_federation = False
config.federation_shard_config.instances = []
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

View file

@ -23,6 +23,7 @@ from typing_extensions import ContextManager
from twisted.internet import address from twisted.internet import address
from twisted.web.resource import IResource from twisted.web.resource import IResource
from twisted.web.server import Request
import synapse import synapse
import synapse.events import synapse.events
@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.main_uri = hs.config.worker_main_http_uri self.main_uri = hs.config.worker_main_http_uri
async def on_POST(self, request, device_id): async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
header: request.requestHeaders.getRawHeaders(header, []) header: request.requestHeaders.getRawHeaders(header, [])
for header in (b"Authorization", b"User-Agent") for header in (b"Authorization", b"User-Agent")
} }
# Add the previous hop the the X-Forwarded-For header. # Add the previous hop to the X-Forwarded-For header.
x_forwarded_for = request.requestHeaders.getRawHeaders( x_forwarded_for = request.requestHeaders.getRawHeaders(
b"X-Forwarded-For", [] b"X-Forwarded-For", []
) )
# we use request.client here, since we want the previous hop, not the
# original client (as returned by request.getClientAddress()).
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)): if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
previous_host = request.client.host.encode("ascii") previous_host = request.client.host.encode("ascii")
# If the header exists, add to the comma-separated list of the first # If the header exists, add to the comma-separated list of the first
@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
x_forwarded_for = [previous_host] x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for headers[b"X-Forwarded-For"] = x_forwarded_for
# Replicate the original X-Forwarded-Proto header. Note that
# XForwardedForRequest overrides isSecure() to give us the original protocol
# used by the client, as opposed to the protocol used by our upstream proxy
# - which is what we want here.
headers[b"X-Forwarded-Proto"] = [
b"https" if request.isSecure() else b"http"
]
try: try:
result = await self.http_client.post_json_get_json( result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers self.main_uri + request.uri.decode("ascii"), body, headers=headers
@ -645,9 +656,6 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)
async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
@cache_in_self @cache_in_self
def get_replication_data_handler(self): def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self) return GenericWorkerReplicationHandler(self)
@ -922,22 +930,6 @@ def start(config_options):
# For other worker types we force this to off. # For other worker types we force this to off.
config.appservice.notify_appservices = False config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.server.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``start_pushers: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.server.start_pushers = True
else:
# For other worker types we force this to off.
config.server.start_pushers = False
if config.worker_app == "synapse.app.user_dir": if config.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory: if config.server.update_user_directory:
sys.stderr.write( sys.stderr.write(
@ -954,22 +946,6 @@ def start(config_options):
# For other worker types we force this to off. # For other worker types we force this to off.
config.server.update_user_directory = False config.server.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.worker.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``send_federation: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.worker.send_federation = True
else:
# For other worker types we force this to off.
config.worker.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
hs = GenericWorkerServer( hs = GenericWorkerServer(

View file

@ -21,7 +21,7 @@ import os
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent from textwrap import dedent
from typing import Any, Iterable, List, MutableMapping, Optional from typing import Any, Iterable, List, MutableMapping, Optional, Union
import attr import attr
import jinja2 import jinja2
@ -147,7 +147,20 @@ class Config:
return int(value) * size return int(value) * size
@staticmethod @staticmethod
def parse_duration(value): def parse_duration(value: Union[str, int]) -> int:
"""Convert a duration as a string or integer to a number of milliseconds.
If an integer is provided it is treated as milliseconds and is unchanged.
String durations can have a suffix of 's', 'm', 'h', 'd', 'w', or 'y'.
No suffix is treated as milliseconds.
Args:
value: The duration to parse.
Returns:
The number of milliseconds in the duration.
"""
if isinstance(value, int): if isinstance(value, int):
return value return value
second = 1000 second = 1000
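Given the docstring above, expected behaviour would be along these lines; an illustrative sketch, assuming the `synapse.config._base` module path shown in this diff.

```python
# Illustrative expectations based on the docstring above; assumes the
# Config class lives at the module path shown in this diff.
from synapse.config._base import Config

assert Config.parse_duration(500) == 500        # ints are already milliseconds
assert Config.parse_duration("500") == 500      # no suffix: milliseconds
assert Config.parse_duration("5s") == 5 * 1000
assert Config.parse_duration("2m") == 2 * 60 * 1000
assert Config.parse_duration("1h") == 60 * 60 * 1000
```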
@ -831,22 +844,23 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool: def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.""" """Whether this instance is responsible for handling the given key."""
# If multiple instances are not defined we always return true # If no instances are defined we assume some other worker is handling
if not self.instances or len(self.instances) == 1: # this.
return True if not self.instances:
return False
return self.get_instance(key) == instance_name return self._get_instance(key) == instance_name
def get_instance(self, key: str) -> str: def _get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key. """Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance Note: For federation sending and pushers the config for which instance
is sending is known only to the sender instance if there is only one. is sending is known only to the sender instance, so we don't expose this
Therefore `should_handle` should be used where possible. method by default.
""" """
if not self.instances: if not self.instances:
return "master" raise Exception("Unknown worker")
if len(self.instances) == 1: if len(self.instances) == 1:
return self.instances[0] return self.instances[0]
@ -863,4 +877,21 @@ class ShardedWorkerHandlingConfig:
return self.instances[remainder] return self.instances[remainder]
@attr.s
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
"""A version of `ShardedWorkerHandlingConfig` that is used for config
options where all instances know which instances are responsible for the
sharded work.
"""
def __attrs_post_init__(self):
# We require that `self.instances` is non-empty.
if not self.instances:
raise Exception("Got empty list of instances for shard config")
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key."""
return self._get_instance(key)
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
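As the final hunk suggests, `_get_instance` assigns each key to a fixed instance by hashing it and taking the remainder modulo the number of instances. A standalone sketch of that scheme follows; the exact hash function Synapse uses may differ.

```python
# Standalone sketch of remainder-based shard routing; Synapse's actual
# hash function may differ.
import hashlib
from typing import List

def pick_instance(instances: List[str], key: str) -> str:
    if not instances:
        raise Exception("Unknown worker")
    if len(instances) == 1:
        return instances[0]
    # A stable hash, so every process routes a given key the same way.
    digest = hashlib.sha256(key.encode("utf8")).digest()
    remainder = int.from_bytes(digest[:8], "big") % len(instances)
    return instances[remainder]

assert pick_instance(["sender1"], "!room:example.com") == "sender1"
assert pick_instance(["sender1", "sender2"], "!room:example.com") in {"sender1", "sender2"}
```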

View file

@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig:
instances: List[str] instances: List[str]
def __init__(self, instances: List[str]) -> None: ... def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ... def should_handle(self, instance_name: str, key: str) -> bool: ...
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ... def get_instance(self, key: str) -> str: ...

View file

@ -41,6 +41,10 @@ class FederationConfig(Config):
) )
self.federation_metrics_domains = set(federation_metrics_domains) self.federation_metrics_domains = set(federation_metrics_domains)
self.allow_profile_lookup_over_federation = config.get(
"allow_profile_lookup_over_federation", True
)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
## Federation ## ## Federation ##
@ -66,6 +70,12 @@ class FederationConfig(Config):
#federation_metrics_domains: #federation_metrics_domains:
# - matrix.org # - matrix.org
# - example.com # - example.com
# Uncomment to disable profile lookup over federation. By default, the
# Federation API allows other homeservers to obtain profile data of any user
# on this homeserver. Defaults to 'true'.
#
#allow_profile_lookup_over_federation: false
""" """

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config, ShardedWorkerHandlingConfig from ._base import Config
class PushConfig(Config): class PushConfig(Config):
@ -27,9 +27,6 @@ class PushConfig(Config):
"group_unread_count_by_room", True "group_unread_count_by_room", True
) )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# There was a 'redact_content' setting but mistakenly read from the # 'email' section. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade. # but do not honour it to avoid nasty surprises when people upgrade.
# but do not honour it to avoid nasty surprises when people upgrade. # but do not honour it to avoid nasty surprises when people upgrade.

View file

@ -102,6 +102,16 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3}, defaults={"per_second": 0.01, "burst_count": 3},
) )
# Ratelimit cross-user key requests:
# * For local requests this is keyed by the sending device.
# * For requests received over federation this is keyed by the origin.
#
# Note that this isn't exposed in the configuration as it is obscure.
self.rc_key_requests = RateLimitConfig(
config.get("rc_key_requests", {}),
defaults={"per_second": 20, "burst_count": 100},
)
self.rc_3pid_validation = RateLimitConfig( self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {}, config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},

View file

@ -207,7 +207,6 @@ class ContentRepositoryConfig(Config):
def generate_config_section(self, data_dir_path, **kwargs): def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store") media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
formatted_thumbnail_sizes = "".join( formatted_thumbnail_sizes = "".join(
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES

View file

@ -263,6 +263,12 @@ class ServerConfig(Config):
False, False,
) )
# Whether to retrieve and display profile data for a user when they
# are invited to a room
self.include_profile_data_on_invite = config.get(
"include_profile_data_on_invite", True
)
if "restrict_public_rooms_to_local_users" in config and ( if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config "allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config or "allow_public_rooms_over_federation" in config
@ -391,7 +397,6 @@ class ServerConfig(Config):
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/": if self.public_baseurl[-1] != "/":
self.public_baseurl += "/" self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit, # (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before # for testing. The value defines the number of milliseconds to pause before
@ -848,6 +853,14 @@ class ServerConfig(Config):
# #
#limit_profile_requests_to_users_who_share_rooms: true #limit_profile_requests_to_users_who_share_rooms: true
# Uncomment to prevent a user's profile data from being retrieved and
# displayed in a room until they have joined it. By default, a user's
# profile data is included in an invite event, regardless of the values
# of the above two settings, and whether or not the users share a server.
# Defaults to 'true'.
#
#include_profile_data_on_invite: false
# If set to 'true', removes the need for authentication to access the server's # If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can # public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'. # query the room directory. Defaults to 'false'.

View file

@ -24,32 +24,46 @@ class UserDirectoryConfig(Config):
section = "userdirectory" section = "userdirectory"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True user_directory_config = config.get("user_directory") or {}
self.user_directory_search_all_users = False self.user_directory_search_enabled = user_directory_config.get("enabled", True)
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
"enabled", True
)
self.user_directory_search_all_users = user_directory_config.get( self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False "search_all_users", False
) )
self.user_directory_search_prefer_local_users = user_directory_config.get(
"prefer_local_users", False
)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """ return """
# User Directory configuration # User Directory configuration
# #
# 'enabled' defines whether users can search the user directory. If user_directory:
# false then empty responses are returned to all queries. Defaults to # Defines whether users can search the user directory. If false then
# true. # empty responses are returned to all queries. Defaults to true.
# #
# 'search_all_users' defines whether to search all users visible to your HS # Uncomment to disable the user directory.
# when searching the user directory, rather than limiting to users visible #
# in public rooms. Defaults to false. If you set it True, you'll have to #enabled: false
# rebuild the user_directory search indexes, see
# Defines whether to search all users visible to your HS when searching
# the user directory, rather than limiting to users visible in public
# rooms. Defaults to false.
#
# If you set it true, you'll have to rebuild the user_directory search
# indexes, see:
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
# #
#user_directory: # Uncomment to return search results containing all known users, even if that
# enabled: true # user does not share a room with the requester.
# search_all_users: false #
#search_all_users: true
# Defines whether to prefer local users in search query results.
# If True, local users are more likely to appear above remote users
# when searching the user directory. Defaults to false.
#
# Uncomment to prefer local over remote users in user directory search
# results.
#
#prefer_local_users: true
""" """

View file

@ -17,9 +17,28 @@ from typing import List, Union
import attr import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from ._base import (
Config,
ConfigError,
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
from .server import ListenerConfig, parse_listener_def from .server import ListenerConfig, parse_listener_def
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
The send_federation config option must be disabled in the main
synapse process before federation sending can be run in a separate worker.
Please add ``send_federation: false`` to the main config
"""
_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """
The start_pushers config option must be disabled in the main
synapse process before pushers can be run in a separate worker.
Please add ``start_pushers: false`` to the main config
"""
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config """Helper for allowing parsing a string or list of strings to a config
@ -103,6 +122,7 @@ class WorkerConfig(Config):
self.worker_replication_secret = config.get("worker_replication_secret", None) self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.instance_name = self.worker_name or "master"
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@ -118,12 +138,41 @@ class WorkerConfig(Config):
) )
) )
# Whether to send federation traffic out in this process. This only # Handle federation sender configuration.
# applies to some federation traffic, and so shouldn't be used to #
# "disable" federation # There are two ways of configuring which instances handle federation
self.send_federation = config.get("send_federation", True) # sending:
# 1. The old way where "send_federation" is set to false and running a
# `synapse.app.federation_sender` worker app.
# 2. Specifying the workers sending federation in
# `federation_sender_instances`.
#
federation_sender_instances = config.get("federation_sender_instances") or [] send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances")
if federation_sender_instances is None:
# Default to an empty list, which means "another, unknown, worker is
# responsible for it".
federation_sender_instances = []
# If no federation sender instances are set we check if
# `send_federation` is set, which means use master
if send_federation:
federation_sender_instances = ["master"]
if self.worker_app == "synapse.app.federation_sender":
if send_federation:
# If we're running federation senders, and not using
# `federation_sender_instances`, then we should have
# explicitly set `send_federation` to false.
raise ConfigError(
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR
)
federation_sender_instances = [self.worker_name]
self.send_federation = self.instance_name in federation_sender_instances
self.federation_shard_config = ShardedWorkerHandlingConfig( self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances federation_sender_instances
) )
@ -164,7 +213,37 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `receipts` messages." "Must only specify one instance to handle `receipts` messages."
) )
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) if len(self.writers.events) == 0:
raise ConfigError("Must specify at least one instance to handle `events`.")
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
# Handle sharded push
start_pushers = config.get("start_pushers", True)
pusher_instances = config.get("pusher_instances")
if pusher_instances is None:
# Default to an empty list, which means "another, unknown, worker is
# responsible for it".
pusher_instances = []
# If no pusher instances are set we check if `start_pushers` is
# set, which means use master
if start_pushers:
pusher_instances = ["master"]
if self.worker_app == "synapse.app.pusher":
if start_pushers:
# If we're running pushers, and not using
# `pusher_instances`, then we should have explicitly set
# `start_pushers` to false.
raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR)
pusher_instances = [self.instance_name]
self.start_pushers = self.instance_name in pusher_instances
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# Whether this worker should run background tasks or not. # Whether this worker should run background tasks or not.
# #
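Condensing the two hunks above, the effective pusher instance list (and, analogously, the federation sender list) is resolved roughly as follows; a simplified paraphrase, not the exact `WorkerConfig` code.

```python
# Simplified paraphrase of the pusher-shard resolution above; not the
# exact WorkerConfig code.
def resolve_pusher_instances(config: dict, worker_app: str, instance_name: str):
    start_pushers = config.get("start_pushers", True)
    pusher_instances = config.get("pusher_instances")
    if pusher_instances is None:
        # No explicit list: master handles pushers by default; an empty
        # list means "another, unknown, worker is responsible".
        pusher_instances = ["master"] if start_pushers else []
        if worker_app == "synapse.app.pusher":
            if start_pushers:
                # Old-style pusher workers require start_pushers: false
                # in the main config.
                raise ValueError("start_pushers must be disabled in the main config")
            pusher_instances = [instance_name]
    return pusher_instances

# start_pushers defaults to true, so a plain config keeps pushers on master:
assert resolve_pusher_instances({}, None, "master") == ["master"]
# An old-style dedicated pusher worker:
assert resolve_pusher_instances(
    {"start_pushers": False}, "synapse.app.pusher", "pusher1"
) == ["pusher1"]
```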

View file

@@ -34,7 +34,7 @@ from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure

-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -44,6 +44,7 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -869,6 +870,13 @@ class FederationHandlerRegistry:
         # EDU received.
         self._edu_type_to_instance = {}  # type: Dict[str, List[str]]

+        # A rate limiter for incoming room key requests per origin.
+        self._room_key_request_rate_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=self.config.rc_key_requests.per_second,
+            burst_count=self.config.rc_key_requests.burst_count,
+        )
+
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
     ):
@@ -917,7 +925,15 @@ class FederationHandlerRegistry:
         self._edu_type_to_instance[edu_type] = instance_names

     async def on_edu(self, edu_type: str, origin: str, content: dict):
-        if not self.config.use_presence and edu_type == "m.presence":
+        if not self.config.use_presence and edu_type == EduTypes.Presence:
+            return
+
+        # If the incoming room key requests from a particular origin are over
+        # the limit, drop them.
+        if (
+            edu_type == EduTypes.RoomKeyRequest
+            and not self._room_key_request_rate_limiter.can_do_action(origin)
+        ):
             return

         # Check if we have a handler on this instance
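The limiter here is keyed by origin server, so one noisy homeserver exhausts only its own budget of `m.room_key_request` EDUs. A minimal token-bucket sketch of a `can_do_action`-style check (the real `Ratelimiter` in `synapse/api/ratelimiting.py` has more options; the refill arithmetic below is an assumption):

```python
import time
from typing import Dict, Hashable, Tuple


class TokenBucketSketch:
    """Toy per-key token bucket; illustration only, not Synapse's Ratelimiter."""

    def __init__(self, rate_hz: float, burst_count: float):
        self.rate_hz = rate_hz          # tokens replenished per second
        self.burst_count = burst_count  # maximum bucket size
        self._buckets = {}  # type: Dict[Hashable, Tuple[float, float]]

    def can_do_action(self, key: Hashable) -> bool:
        now = time.monotonic()
        tokens, last = self._buckets.get(key, (self.burst_count, now))
        # Refill for the elapsed time, capped at the burst size.
        tokens = min(self.burst_count, tokens + (now - last) * self.rate_hz)
        allowed = tokens >= 1
        if allowed:
            tokens -= 1
        self._buckets[key] = (tokens, now)
        return allowed


limiter = TokenBucketSketch(rate_hz=20, burst_count=100)
for _ in range(100):
    assert limiter.can_do_action("remote.example.com")   # within the burst
assert not limiter.can_do_action("remote.example.com")   # 101st is dropped
```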
@@ -474,7 +474,7 @@ class FederationSender:
         self._processing_pending_presence = False

     def send_presence_to_destinations(
-        self, states: List[UserPresenceState], destinations: List[str]
+        self, states: Iterable[UserPresenceState], destinations: Iterable[str]
     ) -> None:
         """Send the given presence states to the given destinations.

             destinations (list[str])
@@ -484,10 +484,9 @@ class FederationQueryServlet(BaseFederationServlet):
     # This is when we receive a server-server Query
     async def on_GET(self, origin, content, query, query_type):
-        return await self.handler.on_query_request(
-            query_type,
-            {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
-        )
+        args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
+        args["origin"] = origin
+        return await self.handler.on_query_request(query_type, args)


 class FederationMakeJoinServlet(BaseFederationServlet):
@@ -36,7 +36,7 @@ import attr
 import bcrypt
 import pymacaroons

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.constants import LoginType
 from synapse.api.errors import (
@@ -481,7 +481,7 @@ class AuthHandler(BaseHandler):
         sid = authdict["session"]

         # Convert the URI and method to strings.
-        uri = request.uri.decode("utf-8")
+        uri = request.uri.decode("utf-8")  # type: ignore
         method = request.method.decode("utf-8")

         # If there's no session ID, create a new session.
@@ -120,6 +120,11 @@ class DeactivateAccountHandler(BaseHandler):
         await self.store.user_set_password_hash(user_id, None)

+        # Most of the pushers will have been deleted when we logged out the
+        # associated devices above, but we still need to delete pushers not
+        # associated with devices, e.g. email pushers.
+        await self.store.delete_all_pushers_for_user(user_id)
+
         # Add the user to a table of users pending deactivation (ie.
         # removal from all the rooms they're a member of)
         await self.store.add_user_pending_deactivation(user_id)
@@ -16,7 +16,9 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict

+from synapse.api.constants import EduTypes
 from synapse.api.errors import SynapseError
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
     get_active_span_text_map,
@@ -25,7 +27,7 @@ from synapse.logging.opentracing import (
     start_active_span,
 )
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -78,6 +80,12 @@ class DeviceMessageHandler:
                 ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )

+        self._ratelimiter = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.rc_key_requests.per_second,
+            burst_count=hs.config.rc_key_requests.burst_count,
+        )
+
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
         sender_user_id = content["sender"]
@@ -168,15 +176,27 @@ class DeviceMessageHandler:
     async def send_device_message(
         self,
-        sender_user_id: str,
+        requester: Requester,
         message_type: str,
         messages: Dict[str, Dict[str, JsonDict]],
     ) -> None:
+        sender_user_id = requester.user.to_string()
+
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
         remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
         for user_id, by_device in messages.items():
+            # Ratelimit local cross-user key requests by the sending device.
+            # (Drop the request when the limiter does NOT allow the action,
+            # mirroring the per-origin check in the federation handler.)
+            if (
+                message_type == EduTypes.RoomKeyRequest
+                and user_id != sender_user_id
+                and not self._ratelimiter.can_do_action(
+                    (sender_user_id, requester.device_id)
+                )
+            ):
+                continue
+
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
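Keying the limiter on `(sender_user_id, requester.device_id)` gives each sending device its own independent budget, so one spamming device cannot starve a user's other devices. A toy illustration of per-key isolation (the `OneShotLimiter` below is hypothetical, with a budget of exactly one action per key):

```python
from typing import Dict, Hashable


class OneShotLimiter:
    """Toy limiter: each key may act exactly once (illustration only)."""

    def __init__(self):
        self._seen = {}  # type: Dict[Hashable, bool]

    def can_do_action(self, key: Hashable) -> bool:
        if self._seen.get(key):
            return False
        self._seen[key] = True
        return True


limiter = OneShotLimiter()
assert limiter.can_do_action(("@alice:hs", "DEVICE_A"))      # first request ok
assert not limiter.can_do_action(("@alice:hs", "DEVICE_A"))  # same device limited
assert limiter.can_do_action(("@alice:hs", "DEVICE_B"))      # other device unaffected
```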
@@ -17,7 +17,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Iterable, List, Optional

-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
@@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler):
                 states = await presence_handler.get_states(users)
                 to_add.extend(
                     {
-                        "type": EventTypes.Presence,
+                        "type": EduTypes.Presence,
                         "content": format_user_presence_state(state, time_now),
                     }
                     for state in states
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple

 from twisted.internet import defer

-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
@@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
             return [
                 {
-                    "type": EventTypes.Presence,
+                    "type": EduTypes.Presence,
                     "content": format_user_presence_state(s, time_now),
                 }
                 for s in states
@@ -387,6 +387,12 @@ class EventCreationHandler:

         self.room_invite_state_types = self.hs.config.room_invite_state_types

+        self.membership_types_to_include_profile_data_in = (
+            {Membership.JOIN, Membership.INVITE}
+            if self.hs.config.include_profile_data_on_invite
+            else {Membership.JOIN}
+        )
+
         self.send_event = ReplicationSendEventRestServlet.make_client(hs)

         # This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -502,7 +508,7 @@ class EventCreationHandler:
             membership = builder.content.get("membership", None)
             target = UserID.from_string(builder.state_key)

-            if membership in {Membership.JOIN, Membership.INVITE}:
+            if membership in self.membership_types_to_include_profile_data_in:
                 # If event doesn't include a display name, add one.
                 profile = self.profile_handler
                 content = builder.content
@@ -274,6 +274,7 @@ class PresenceHandler(BasePresenceHandler):

         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")

+        if self._presence_enabled:
             # Start a LoopingCall in 30s that fires every 5s.
             # The initial delay is to allow disconnected clients a chance to
             # reconnect before we treat them as offline.
@@ -282,7 +283,9 @@ class PresenceHandler(BasePresenceHandler):
                     "handle_presence_timeouts", self._handle_timeouts
                 )

-            self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
+            self.clock.call_later(
+                30, self.clock.looping_call, run_timeout_handler, 5000
+            )

             def run_persister():
                 return run_as_background_process(
@@ -299,7 +302,7 @@ class PresenceHandler(BasePresenceHandler):
             )

         # Used to handle sending of presence to newly joined users/servers
-        if hs.config.use_presence:
+        if self._presence_enabled:
             self.notifier.add_replication_callback(self.notify_new_event)

         # Presence is best effort and quickly heals itself, so lets just always
@@ -849,6 +852,9 @@ class PresenceHandler(BasePresenceHandler):
         """Process current state deltas to find new joins that need to be
         handled.
         """
+        # A map of destination to a set of user state that they should receive
+        presence_destinations = {}  # type: Dict[str, Set[UserPresenceState]]
+
         for delta in deltas:
             typ = delta["type"]
             state_key = delta["state_key"]
@@ -858,6 +864,7 @@ class PresenceHandler(BasePresenceHandler):

             logger.debug("Handling: %r %r, %s", typ, state_key, event_id)

+            # Drop any event that isn't a membership join
             if typ != EventTypes.Member:
                 continue

@@ -880,13 +887,38 @@ class PresenceHandler(BasePresenceHandler):
                     # Ignore changes to join events.
                     continue

-            await self._on_user_joined_room(room_id, state_key)
+            # Retrieve any user presence state updates that need to be sent as a result,
+            # and the destinations that need to receive it
+            destinations, user_presence_states = await self._on_user_joined_room(
+                room_id, state_key
+            )
+
+            # Insert the destinations and respective updates into our destinations dict
+            for destination in destinations:
+                presence_destinations.setdefault(destination, set()).update(
+                    user_presence_states
+                )
+
+        # Send out user presence updates for each destination
+        for destination, user_state_set in presence_destinations.items():
+            self.federation.send_presence_to_destinations(
+                destinations=[destination], states=user_state_set
+            )

-    async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
+    async def _on_user_joined_room(
+        self, room_id: str, user_id: str
+    ) -> Tuple[List[str], List[UserPresenceState]]:
         """Called when we detect a user joining the room via the current state
-        delta stream.
-        """
+        delta stream. Returns the destinations that need to be updated and the
+        presence updates to send to them.
+
+        Args:
+            room_id: The ID of the room that the user has joined.
+            user_id: The ID of the user that has joined the room.
+
+        Returns:
+            A tuple of destinations and presence updates to send to them.
+        """
         if self.is_mine_id(user_id):
             # If this is a local user then we need to send their presence
             # out to hosts in the room (who don't already have it)
@@ -894,15 +926,15 @@ class PresenceHandler(BasePresenceHandler):
             # TODO: We should be able to filter the hosts down to those that
             # haven't previously seen the user

-            state = await self.current_state_for_user(user_id)
-            hosts = await self.state.get_current_hosts_in_room(room_id)
+            remote_hosts = await self.state.get_current_hosts_in_room(room_id)

             # Filter out ourselves.
-            hosts = {host for host in hosts if host != self.server_name}
+            filtered_remote_hosts = [
+                host for host in remote_hosts if host != self.server_name
+            ]

-            self.federation.send_presence_to_destinations(
-                states=[state], destinations=hosts
-            )
+            state = await self.current_state_for_user(user_id)
+            return filtered_remote_hosts, [state]
         else:
             # A remote user has joined the room, so we need to:
             # 1. Check if this is a new server in the room
@@ -915,6 +947,8 @@ class PresenceHandler(BasePresenceHandler):
             # TODO: Check that this is actually a new server joining the
             # room.

+            remote_host = get_domain_from_id(user_id)
+
             users = await self.state.get_current_users_in_room(room_id)
             user_ids = list(filter(self.is_mine_id, users))

@@ -934,10 +968,7 @@ class PresenceHandler(BasePresenceHandler):
                 or state.status_msg is not None
             ]

-            if states:
-                self.federation.send_presence_to_destinations(
-                    states=states, destinations=[get_domain_from_id(user_id)]
-                )
+            return [remote_host], states

     def should_notify(old_state, new_state):
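The refactor above turns "send presence per join event" into "collect, then send one batch per destination", which is what removes the flood of redundant presence updates on join. A self-contained sketch of that aggregation pattern (types are simplified stand-ins):

```python
from typing import Dict, List, Set, Tuple

UserPresenceState = str  # stand-in for the attrs class in synapse.handlers.presence


def batch_presence(
    joins: List[Tuple[List[str], List[UserPresenceState]]]
) -> Dict[str, Set[UserPresenceState]]:
    """joins: (destinations, states) pairs as returned by _on_user_joined_room."""
    presence_destinations = {}  # type: Dict[str, Set[UserPresenceState]]
    for destinations, states in joins:
        for destination in destinations:
            # Each destination accumulates the union of all updates it needs,
            # so it later receives a single send_presence_to_destinations call.
            presence_destinations.setdefault(destination, set()).update(states)
    return presence_destinations


batched = batch_presence(
    [(["remote1.org", "remote2.org"], ["@a:hs online"]),
     (["remote1.org"], ["@b:hs unavailable"])]
)
assert batched["remote1.org"] == {"@a:hs online", "@b:hs unavailable"}
```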
@@ -310,6 +310,15 @@ class ProfileHandler(BaseHandler):
             await self._update_join_states(requester, target_user)

     async def on_profile_query(self, args: JsonDict) -> JsonDict:
+        """Handles federation profile query requests."""
+
+        if not self.hs.config.allow_profile_lookup_over_federation:
+            raise SynapseError(
+                403,
+                "Profile lookup over federation is disabled on this homeserver",
+                Codes.FORBIDDEN,
+            )
+
         user = UserID.from_string(args["user_id"])
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -31,8 +31,8 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import NoReturn, Protocol

-from twisted.web.http import Request
 from twisted.web.iweb import IRequest
+from twisted.web.server import Request

 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Union

-from twisted.internet import task
+from twisted.internet import address, task
 from twisted.web.client import FileBodyProducer
 from twisted.web.iweb import IRequest
@@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
         pass


+def get_request_uri(request: IRequest) -> bytes:
+    """Return the full URI that was requested by the client"""
+    return b"%s://%s%s" % (
+        b"https" if request.isSecure() else b"http",
+        _get_requested_host(request),
+        # despite its name, "request.uri" is only the path and query-string.
+        request.uri,
+    )
+
+
+def _get_requested_host(request: IRequest) -> bytes:
+    hostname = request.getHeader(b"host")
+    if hostname:
+        return hostname
+
+    # no Host header, use the address/port that the request arrived on
+    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+
+    hostname = host.host.encode("ascii")
+
+    if request.isSecure() and host.port == 443:
+        # default port for https
+        return hostname
+
+    if not request.isSecure() and host.port == 80:
+        # default port for http
+        return hostname
+
+    return b"%s:%i" % (
+        hostname,
+        host.port,
+    )
+
+
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
     """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.
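A quick usage sketch of the new `get_request_uri`, showing the default-port rule: the port is appended to the host only when it is non-standard for the scheme. The fake request below is a duck-typed stand-in (not a real twisted `IRequest`), and running it assumes this version of Synapse is importable:

```python
from synapse.http import get_request_uri  # added in this commit


class _FakeAddress:
    def __init__(self, host, port):
        self.host, self.port = host, port


class _FakeRequest:
    """Minimal duck-typed stand-in for IRequest (illustration only)."""

    def __init__(self, host_header, host, port, secure, uri=b"/path?q=1"):
        self._host_header, self._secure, self.uri = host_header, secure, uri
        self._addr = _FakeAddress(host, port)

    def getHeader(self, name):
        return self._host_header if name == b"host" else None

    def getHost(self):
        return self._addr

    def isSecure(self):
        return self._secure


# No Host header, https on a non-default port: the port is kept.
req = _FakeRequest(host_header=None, host="example.com", port=8448, secure=True)
assert get_request_uri(req) == b"https://example.com:8448/path?q=1"

# https on 443: the default port is omitted.
req = _FakeRequest(host_header=None, host="example.com", port=443, secure=True)
assert get_request_uri(req) == b"https://example.com/path?q=1"
```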
@@ -289,9 +289,8 @@ class SimpleHttpClient:
         treq_args: Dict[str, Any] = {},
         ip_whitelist: Optional[IPSet] = None,
         ip_blacklist: Optional[IPSet] = None,
-        http_proxy: Optional[bytes] = None,
-        https_proxy: Optional[bytes] = None,
         user_agent: Optional[str] = None,
+        use_proxy: bool = False,
     ):
         """
         Args:
@@ -301,8 +300,8 @@ class SimpleHttpClient:
                 we may not request.
             ip_whitelist: The whitelisted IP addresses, that we can
                 request if it were otherwise caught in a blacklist.
-            http_proxy: proxy server to use for http connections. host[:port]
-            https_proxy: proxy server to use for https connections. host[:port]
+            use_proxy: Whether proxy settings should be discovered and used
+                from conventional environment variables.
         """
         self.hs = hs
@@ -346,8 +345,7 @@ class SimpleHttpClient:
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
-            http_proxy=http_proxy,
-            https_proxy=https_proxy,
+            use_proxy=use_proxy,
         )

         if self._ip_blacklist:
@@ -751,7 +749,32 @@ class BodyExceededMaxSize(Exception):
     """The maximum allowed size of the HTTP body was exceeded."""


+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which immediately errors upon receiving data."""
+
+    def __init__(self, deferred: defer.Deferred):
+        self.deferred = deferred
+
+    def _maybe_fail(self):
+        """
+        Report a max size exceed error and disconnect the first time this is called.
+        """
+        if not self.deferred.called:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def dataReceived(self, data: bytes) -> None:
+        self._maybe_fail()
+
+    def connectionLost(self, reason: Failure) -> None:
+        self._maybe_fail()
+
+
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -808,13 +831,15 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    d = defer.Deferred()
+
     # If the Content-Length header gives a size larger than the maximum allowed
     # size, do not bother downloading the body.
     if max_size is not None and response.length != UNKNOWN_LENGTH:
         if response.length > max_size:
-            return defer.fail(BodyExceededMaxSize())
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d

-    d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
     return d
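The change above matters because an oversized body must still be drained or the connection aborted; simply failing the Deferred left the response unread on the transport. A hedged caller-side sketch (the response/agent setup is elided, and this assumes a standard twisted.web `IResponse` object):

```python
from io import BytesIO

from twisted.internet import defer

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


@defer.inlineCallbacks
def fetch_body(response, max_size=1024 * 1024):
    """Stream a response body into memory, bounded by max_size."""
    buffer = BytesIO()
    try:
        yield read_body_with_max_size(response, buffer, max_size)
    except BodyExceededMaxSize:
        # With this change the oversized body is actively discarded and the
        # connection aborted, rather than left to wedge the transport.
        return None
    return buffer.getvalue()
```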
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional

 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()

-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)

         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )
         return res
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON
     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )

-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 import re
+from urllib.request import getproxies_environment, proxy_bypass_environment

 from zope.interface import implementer
@@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
         pool (HTTPConnectionPool|None): connection pool to be used. If None, a
             non-persistent pool instance will be created.

+        use_proxy (bool): Whether proxy settings should be discovered and used
+            from conventional environment variables.
     """

     def __init__(
@@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
         connectTimeout=None,
         bindAddress=None,
         pool=None,
-        http_proxy=None,
-        https_proxy=None,
+        use_proxy=False,
     ):
         _AgentBase.__init__(self, reactor, pool)
@@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
         if bindAddress is not None:
             self._endpoint_kwargs["bindAddress"] = bindAddress

+        http_proxy = None
+        https_proxy = None
+        no_proxy = None
+        if use_proxy:
+            proxies = getproxies_environment()
+            http_proxy = proxies["http"].encode() if "http" in proxies else None
+            https_proxy = proxies["https"].encode() if "https" in proxies else None
+            no_proxy = proxies["no"] if "no" in proxies else None
+
         self.http_proxy_endpoint = _http_proxy_endpoint(
             http_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
@@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
             https_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )

+        self.no_proxy = no_proxy
+
         self._policy_for_https = contextFactory
         self._reactor = reactor
@@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
         pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
         request_path = parsed_uri.originForm

-        if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+        should_skip_proxy = False
+        if self.no_proxy is not None:
+            should_skip_proxy = proxy_bypass_environment(
+                parsed_uri.host.decode(),
+                proxies={"no": self.no_proxy},
+            )
+
+        if (
+            parsed_uri.scheme == b"http"
+            and self.http_proxy_endpoint
+            and not should_skip_proxy
+        ):
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
             pool_key = ("http-proxy", self.http_proxy_endpoint)
             endpoint = self.http_proxy_endpoint
             request_path = uri
-        elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+        elif (
+            parsed_uri.scheme == b"https"
+            and self.https_proxy_endpoint
+            and not should_skip_proxy
+        ):
             endpoint = HTTPConnectProxyEndpoint(
                 self.proxy_reactor,
                 self.https_proxy_endpoint,
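The two stdlib helpers the `ProxyAgent` now leans on are worth seeing in isolation: `getproxies_environment()` collects `<scheme>_proxy` environment variables (preferring the lowercase forms) into a dict keyed by scheme, with `no_proxy` under the key `"no"`, and `proxy_bypass_environment()` implements the `no_proxy` host matching that `should_skip_proxy` relies on:

```python
import os
from urllib.request import getproxies_environment, proxy_bypass_environment

os.environ["http_proxy"] = "http://proxy.example.com:8888"
os.environ["no_proxy"] = "localhost,internal.example.com"

proxies = getproxies_environment()
assert proxies["http"] == "http://proxy.example.com:8888"

# Hosts matched by no_proxy bypass the proxy, exactly what should_skip_proxy checks.
assert proxy_bypass_environment("internal.example.com", proxies={"no": proxies["no"]})
assert not proxy_bypass_environment("matrix.org", proxies={"no": proxies["no"]})
```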
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""

     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()

-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )

     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
@@ -142,7 +147,7 @@ def return_html_error(
             logger.error(
                 "Failed handle request %r",
                 request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
             )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )

     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             raw_callback_return = method_handler(request)

             # Is it synchronous? We'll allow this for now.
-            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
                 callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return

         self._request.write(b"".join(data))
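The `isawaitable()` switch is a strict superset of the old `isinstance` check: it catches coroutines, Deferreds (which implement `__await__` in modern Twisted) and any other awaitable with a single test. A quick demonstration:

```python
from inspect import isawaitable

from twisted.internet import defer


async def async_handler():
    return 42


def sync_handler():
    return 42


coro = async_handler()
assert isawaitable(coro)                 # coroutine object
coro.close()                             # avoid an "un-awaited coroutine" warning

assert isawaitable(defer.Deferred())     # Deferred implements __await__
assert not isawaitable(sync_handler())   # plain return value: used directly
```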
@@ -14,8 +14,12 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union

+import attr
+from zope.interface import implementer
+
+from twisted.internet.interfaces import IAddress
 from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
@@ -53,7 +57,7 @@ class SynapseRequest(Request):
     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -92,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)

-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)

-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).

         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.

         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method

     def render(self, resrc):
@@ -333,27 +346,78 @@ class SynapseRequest(Request):

 class XForwardedForRequest(SynapseRequest):
-    def __init__(self, *args, **kw):
-        SynapseRequest.__init__(self, *args, **kw)
-        """
-        Add a layer on top of another request that only uses the value of an
-        X-Forwarded-For header as the result of C{getClientIP}.
-        """
+    """Request object which honours proxy headers

-    def getClientIP(self):
-        """
-        @return: The client address (the first address) in the value of the
-            I{X-Forwarded-For header}.  If the header is not present, return
-            C{b"-"}.
-        """
-        return (
-            self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
-            .split(b",")[0]
-            .strip()
-            .decode("ascii")
-        )
+    Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
+    information from request headers.
+    """

+    # the client IP and ssl flag, as extracted from the headers.
+    _forwarded_for = None  # type: Optional[_XForwardedForAddress]
+    _forwarded_https = False  # type: bool
+
+    def requestReceived(self, command, path, version):
+        # this method is called by the Channel once the full request has been
+        # received, to dispatch the request to a resource.
+        # We can use it to set the IP address and protocol according to the
+        # headers.
+        self._process_forwarded_headers()
+        return super().requestReceived(command, path, version)
+
+    def _process_forwarded_headers(self):
+        headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
+        if not headers:
+            return
+
+        # for now, we just use the first x-forwarded-for header. Really, we ought
+        # to start from the client IP address, and check whether it is trusted; if it
+        # is, work backwards through the headers until we find an untrusted address.
+        # see https://github.com/matrix-org/synapse/issues/9471
+        self._forwarded_for = _XForwardedForAddress(
+            headers[0].split(b",")[0].strip().decode("ascii")
+        )
+
+        # if we got an x-forwarded-for header, also look for an x-forwarded-proto header
+        header = self.getHeader(b"x-forwarded-proto")
+        if header is not None:
+            self._forwarded_https = header.lower() == b"https"
+        else:
+            # this is done largely for backwards-compatibility so that people that
+            # haven't set an x-forwarded-proto header don't get a redirect loop.
+            logger.warning(
+                "forwarded request lacks an x-forwarded-proto header: assuming https"
+            )
+            self._forwarded_https = True
+
+    def isSecure(self):
+        if self._forwarded_https:
+            return True
+        return super().isSecure()
+
+    def getClientIP(self) -> str:
+        """
+        Return the IP address of the client who submitted this request.
+
+        This method is deprecated.  Use getClientAddress() instead.
+        """
+        if self._forwarded_for is not None:
+            return self._forwarded_for.host
+        return super().getClientIP()
+
+    def getClientAddress(self) -> IAddress:
+        """
+        Return the address of the client who submitted this request.
+        """
+        if self._forwarded_for is not None:
+            return self._forwarded_for
+        return super().getClientAddress()
+
+
+@implementer(IAddress)
+@attr.s(frozen=True, slots=True)
+class _XForwardedForAddress:
+    host = attr.ib(type=str)
+

 class SynapseSite(Site):
     """
@@ -377,7 +441,9 @@ class SynapseSite(Site):
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
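The core of the new header handling, reduced to a standalone function: take the first (client-most) hop of `X-Forwarded-For`, and treat a missing `X-Forwarded-Proto` as https for backwards compatibility. This is a simplification of `_process_forwarded_headers` (no trust-chain walking; see issue #9471), with illustrative names:

```python
from typing import Dict, Optional, Tuple


def parse_forwarded(headers: Dict[bytes, bytes]) -> Tuple[Optional[str], bool]:
    """Return (client_ip, is_https) from proxy headers, if present."""
    client_ip = None
    xff = headers.get(b"x-forwarded-for")
    if xff:
        # take the first hop of the comma-separated proxy chain
        client_ip = xff.split(b",")[0].strip().decode("ascii")

    proto = headers.get(b"x-forwarded-proto")
    # an absent header is treated as https to avoid redirect loops on
    # deployments that never set it
    is_https = proto.lower() == b"https" if proto is not None else True
    return client_ip, is_https


assert parse_forwarded(
    {b"x-forwarded-for": b"203.0.113.9, 10.0.0.1", b"x-forwarded-proto": b"https"}
) == ("203.0.113.9", True)
```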
@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.failure import Failure
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+                endpoint = TCP4ClientEndpoint(
+                    _reactor, self.host, self.port
+                )  # type: IStreamClientEndpoint
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
 REGISTRY.register(ReactorLastSeenMetric())


-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
     def f(*args, **kwargs):
         now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):

 try:
     # Ensure the reactor has all the attributes we expect
-    reactor.runUntilCurrent
-    reactor._newTimedCalls
-    reactor.threadCallQueue
+    reactor.seconds  # type: ignore
+    reactor.runUntilCurrent  # type: ignore
+    reactor._newTimedCalls  # type: ignore
+    reactor.threadCallQueue  # type: ignore

     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteration after fd polling.
-    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent)  # type: ignore

     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
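The instrumentation pattern above, passing the reactor explicitly instead of relying on a global, boils down to monkey-patching a timed wrapper over `reactor.runUntilCurrent`. A minimal self-contained sketch (the fake reactor and the printed "metric" are stand-ins for the real Prometheus counters):

```python
import functools
import time


def run_until_current_timer(reactor, func):
    @functools.wraps(func)  # preserve the wrapped method's name/signature
    def f(*args, **kwargs):
        start = reactor.seconds()
        ret = func(*args, **kwargs)
        # the real code records this duration in Prometheus metrics
        print("reactor tick took %.6fs" % (reactor.seconds() - start))
        return ret

    return f


class FakeReactor:
    def seconds(self):
        return time.monotonic()

    def runUntilCurrent(self):
        pass  # would process pending calls


reactor = FakeReactor()
reactor.runUntilCurrent = run_until_current_timer(reactor, reactor.runUntilCurrent)
reactor.runUntilCurrent()
```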
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple

 from twisted.internet import defer
@@ -307,7 +307,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """Gets current state events for the given room.

         (This is exposed for compatibility with the old SpamCheckerApi. We should
@@ -15,11 +15,12 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union

 from prometheus_client import Counter

 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall

 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -71,9 +72,10 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
+        self._pusherpool = hs.get_pusherpool()

         self.data = pusher_config.data
         if self.data is None:
@@ -292,7 +294,7 @@ class HttpPusher(Pusher):
                 )
             else:
                 logger.info("Pushkey %s was rejected: removing", pk)
-                await self.hs.remove_pusher(self.app_id, pk, self.user_id)
+                await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id)
         return True

     async def _build_notification_dict(
@@ -19,12 +19,14 @@ from typing import TYPE_CHECKING, Dict, Iterable, Optional

 from prometheus_client import Gauge

+from synapse.api.errors import Codes, SynapseError
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
 from synapse.push import Pusher, PusherConfig, PusherConfigException
 from synapse.push.pusher import PusherFactory
+from synapse.replication.http.push import ReplicationRemovePusherRestServlet
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
@@ -58,7 +60,6 @@ class PusherPool:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.pusher_factory = PusherFactory(hs)
-        self._should_start_pushers = hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
@@ -67,6 +68,16 @@ class PusherPool:
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
+        self._should_start_pushers = (
+            self._instance_name in self._pusher_shard_config.instances
+        )
+
+        # We can only delete pushers on master.
+        self._remove_pusher_client = None
+        if hs.config.worker.worker_app:
+            self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(
+                hs
+            )

         # Record the last stream ID that we were poked about so we can get
         # changes since then. We set this to the current max stream ID on
@@ -103,6 +114,11 @@ class PusherPool:
             The newly created pusher.
         """

+        if kind == "email":
+            email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
+            if email_owner != user_id:
+                raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
+
         time_now_msec = self.clock.time_msec()

         # create the pusher setting last_stream_ordering to the current maximum
@@ -175,9 +191,6 @@ class PusherPool:
             user_id: user to remove pushers for
             access_tokens: access token *ids* to remove pushers for
         """
-        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
-            return
-
         tokens = set(access_tokens)
         for p in await self.store.get_pushers_by_user_id(user_id):
             if p.access_token in tokens:
@@ -380,6 +393,12 @@ class PusherPool:

         synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()

-        await self.store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id, pushkey, user_id
-        )
+        # We can only delete pushers on master.
+        if self._remove_pusher_client:
+            await self._remove_pusher_client(
+                app_id=app_id, pushkey=pushkey, user_id=user_id
+            )
+        else:
+            await self.store.delete_pusher_by_app_id_pushkey_user_id(
+                app_id, pushkey, user_id
+            )
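The routing decision at the end of that hunk, isolated for clarity: workers forward pusher deletion to the master over replication HTTP, while the master writes to the database directly. A runnable sketch with stand-in client/store objects (illustrative names, not Synapse classes):

```python
import asyncio


async def remove_pusher(app_id, pushkey, user_id, *, remove_pusher_client, store):
    if remove_pusher_client is not None:
        # running on a worker: ask the master over replication HTTP
        await remove_pusher_client(app_id=app_id, pushkey=pushkey, user_id=user_id)
    else:
        # running on the master: delete directly from the database
        await store.delete_pusher_by_app_id_pushkey_user_id(app_id, pushkey, user_id)


class FakeStore:
    async def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
        print("master: deleted", app_id, pushkey, user_id)


asyncio.run(remove_pusher("m.http", "a_pushkey", "@u:hs",
                          remove_pusher_client=None, store=FakeStore()))
```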
@@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = {
         "pysaml2>=4.5.0;python_version>='3.6'",
     ],
     "oidc": ["authlib>=0.14.0"],
+    # systemd-python is necessary for logging to the systemd journal via
+    # `systemd.journal.JournalHandler`, as is documented in
+    # `contrib/systemd/log_config.yaml`.
     "systemd": ["systemd-python>=231"],
     "url_preview": ["lxml>=3.5.0"],
     "sentry": ["sentry-sdk>=0.7.2"],
@@ -21,6 +21,7 @@ from synapse.replication.http import (
     login,
     membership,
     presence,
+    push,
     register,
     send_event,
     streams,
@@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource):
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
         account_data.register_servlets(hs, self)
+        push.register_servlets(hs, self)

         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
@@ -213,8 +213,9 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         content = parse_json_object_from_request(request)

         args = content["args"]
+        args["origin"] = content["origin"]

-        logger.info("Got %r query", query_type)
+        logger.info("Got %r query from %s", query_type, args["origin"])

         result = await self.registry.on_query(query_type, args)
@@ -15,9 +15,10 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util.distributor import user_left_room
@@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         }

     async def _handle_request(  # type: ignore
-        self, request: Request, room_id: str, user_id: str
+        self, request: SynapseRequest, room_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
@@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         event_content = content["content"]

         requester = Requester.deserialize(self.store, content["requester"])
         request.requester = requester

         logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -147,7 +147,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         }

     async def _handle_request(  # type: ignore
-        self, request: Request, invite_event_id: str
+        self, request: SynapseRequest, invite_event_id: str
     ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
@@ -155,7 +155,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         event_content = content["content"]

         requester = Requester.deserialize(self.store, content["requester"])
         request.requester = requester

         # hopefully we're now on the master, so this won't recurse!
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
+    """Deletes the given pusher.
+
+    Request format:
+
+        POST /_synapse/replication/remove_pusher/:user_id
+
+        {
+            "app_id": "<some_id>",
+            "pushkey": "<some_key>"
+        }
+
+    """
+
+    # NAME determines the replication URL, so it must match the path in the
+    # docstring above.
+    NAME = "remove_pusher"
+    PATH_ARGS = ("user_id",)
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    @staticmethod
+    async def _serialize_payload(app_id, pushkey, user_id):
+        payload = {
+            "app_id": app_id,
+            "pushkey": pushkey,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id):
+        content = parse_json_object_from_request(request)
+
+        app_id = content["app_id"]
+        pushkey = content["pushkey"]
+
+        await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
+
+        return 200, {}
+
+
+def register_servlets(hs, http_server):
+    ReplicationRemovePusherRestServlet(hs).register(http_server)
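To make the servlet's contract concrete: `make_client(hs)` gives workers a callable that POSTs the serialized payload to the master at `/_synapse/replication/<NAME>/<path args>`. A hedged sketch of the URL construction only (the base URL and helper below are illustrative; the actual HTTP call is handled by `ReplicationEndpoint`):

```python
from urllib.parse import quote


def replication_url(name: str, path_args: list, base: str = "http://master:9093") -> str:
    # each PATH_ARGS value becomes a URL-escaped path segment
    escaped = "/".join(quote(arg, safe="") for arg in path_args)
    return "%s/_synapse/replication/%s/%s" % (base, name, escaped)


# A pusher worker deleting a rejected pushkey would POST
# {"app_id": ..., "pushkey": ...} to:
url = replication_url("remove_pusher", ["@user:example.com"])
assert url == "http://master:9093/_synapse/replication/remove_pusher/%40user%3Aexample.com"
```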
@@ -108,9 +108,7 @@ class ReplicationDataHandler:

         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = (
-            {}
-        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]

     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -325,31 +325,6 @@ class FederationAckCommand(Command):
         return "%s %s" % (self.instance_name, self.token)


-class RemovePusherCommand(Command):
-    """Sent by the client to request the master remove the given pusher.
-
-    Format::
-
-        REMOVE_PUSHER <app_id> <push_key> <user_id>
-    """
-
-    NAME = "REMOVE_PUSHER"
-
-    def __init__(self, app_id, push_key, user_id):
-        self.user_id = user_id
-        self.app_id = app_id
-        self.push_key = push_key
-
-    @classmethod
-    def from_line(cls, line):
-        app_id, push_key, user_id = line.split(" ", 2)
-        return cls(app_id, push_key, user_id)
-
-    def to_line(self):
-        return " ".join((self.app_id, self.push_key, self.user_id))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
@@ -416,7 +391,6 @@ _COMMANDS = (
     ReplicateCommand,
     UserSyncCommand,
     FederationAckCommand,
-    RemovePusherCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
     UserSyncCommand.NAME,
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
-    RemovePusherCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,


@@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import (
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
-    RemovePusherCommand,
     ReplicateCommand,
     UserIpCommand,
     UserSyncCommand,
@@ -373,23 +372,6 @@ class ReplicationCommandHandler:
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    def on_REMOVE_PUSHER(
-        self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ) -> Optional[Awaitable[None]]:
-        remove_pusher_counter.inc()
-
-        if self._is_master:
-            return self._handle_remove_pusher(cmd)
-        else:
-            return None
-
-    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
-        await self._store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-        )
-
-        self._notifier.on_new_replication_data()
-
     def on_USER_IP(
         self, conn: AbstractConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
@@ -684,11 +666,6 @@ class ReplicationCommandHandler:
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
         )
 
-    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        """Poke the master to remove a pusher for a user"""
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,


@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream",
+        "AccountDataStreamRow",
         ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
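This hunk corrects the namedtuple's typename so it matches the attribute it is bound to. A standalone illustration of why the first argument matters (the values below are made up):

    from collections import namedtuple

    # The first argument is the typename baked into the generated class and
    # its repr(); a mismatch is harmless at runtime but misleading in logs
    # and debuggers.
    AccountDataStreamRow = namedtuple(
        "AccountDataStreamRow", ("user_id", "room_id", "data_type")
    )
    row = AccountDataStreamRow("@alice:example.org", None, "m.push_rules")
    print(row)  # AccountDataStreamRow(user_id='@alice:example.org', ...)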


@@ -145,7 +145,7 @@
             <input type="submit" value="Continue" class="primary-button">
             {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
             <section class="idp-pick-details">
-                <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
+                <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
             {% if user_attributes.avatar_url %}
                 <label class="idp-detail idp-avatar" for="idp-avatar">
                     <div class="check-row">


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -20,8 +21,12 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
         "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id, device_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
         )
         return 200, device
 
-    async def on_DELETE(self, request, user_id, device_id):
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
         await self.device_handler.delete_device(target_user.to_string(), device_id)
         return 200, {}
 
-    async def on_PUT(self, request, user_id, device_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         """
         Args:
             hs (synapse.server.HomeServer): server
@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)


@@ -14,10 +14,16 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/event_reports$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, report_id):
+    async def on_GET(
+        self, request: SynapseRequest, report_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         message = (
             "The report_id parameter must be a string representing a positive integer."
         )
         try:
-            report_id = int(report_id)
+            resolved_report_id = int(report_id)
         except ValueError:
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
 
-        if report_id < 0:
+        if resolved_report_id < 0:
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
 
-        ret = await self.store.get_event_report(report_id)
+        ret = await self.store.get_event_report(resolved_report_id)
         if not ret:
             raise NotFoundError("Event report not found")
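Rebinding the converted value to `resolved_report_id`, rather than overwriting the `report_id` string, keeps the parameter's type stable for mypy. A standalone sketch of the same parse-then-validate pattern (function name hypothetical):

    def resolve_report_id(report_id: str) -> int:
        """Convert a path parameter to a non-negative int, rejecting bad
        syntax and negative values with the same error message."""
        message = (
            "The report_id parameter must be a string representing a positive integer."
        )
        try:
            resolved_report_id = int(report_id)
        except ValueError:
            raise ValueError(message)
        if resolved_report_id < 0:
            raise ValueError(message)
        return resolved_report_id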


@@ -17,7 +17,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer


@@ -44,6 +44,48 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
 
+class ResolveRoomIdMixin:
+    def __init__(self, hs: "HomeServer"):
+        self.room_member_handler = hs.get_room_member_handler()
+
+    async def resolve_room_id(
+        self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
+    ) -> Tuple[str, Optional[List[str]]]:
+        """
+        Resolve a room identifier to a room ID, if necessary.
+
+        This also performs checks to ensure the room ID is of the proper form.
+
+        Args:
+            room_identifier: The room ID or alias.
+            remote_room_hosts: The potential remote room hosts to use.
+
+        Returns:
+            The resolved room ID.
+
+        Raises:
+            SynapseError if the room ID is of the wrong form.
+        """
+        if RoomID.is_valid(room_identifier):
+            resolved_room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            (
+                room_id,
+                remote_room_hosts,
+            ) = await self.room_member_handler.lookup_room_alias(room_alias)
+            resolved_room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+        if not resolved_room_id:
+            raise SynapseError(
+                400, "Unknown room ID or room alias %s" % room_identifier
+            )
+        return resolved_room_id, remote_room_hosts
+
+
 class ShutdownRoomRestServlet(RestServlet):
     """Shuts down a room by removing all local users from the room and blocking
     all future invites and joins to the room. Any local aliases will be repointed
 
@@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
         return 200, ret
 
 
-class JoinRoomAliasServlet(RestServlet):
+class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
 
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
 
@@ -362,21 +404,15 @@ class JoinRoomAliasServlet(RestServlet):
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")
 
-        if RoomID.is_valid(room_identifier):
-            room_id = room_identifier
-            try:
-                remote_room_hosts = [
-                    x.decode("ascii") for x in request.args[b"server_name"]
-                ]  # type: Optional[List[str]]
-            except Exception:
-                remote_room_hosts = None
-        elif RoomAlias.is_valid(room_identifier):
-            handler = self.room_member_handler
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
+        # Get the room ID from the identifier.
+        try:
+            remote_room_hosts = [
+                x.decode("ascii") for x in request.args[b"server_name"]
+            ]  # type: Optional[List[str]]
+        except Exception:
+            remote_room_hosts = None
+        room_id, remote_room_hosts = await self.resolve_room_id(
+            room_identifier, remote_room_hosts
+        )
 
         fake_requester = create_requester(
 
@@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
 
         return 200, {"room_id": room_id}
 
 
-class MakeRoomAdminRestServlet(RestServlet):
+class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
     """Allows a server admin to get power in a room if a local user has power in
     a room. Will also invite the user if they're not in the room and it's a
     private room. Can specify another user (rather than the admin user) to be
 
@@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
 
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
 
-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         content = parse_json_object_from_request(request, allow_empty_body=True)
 
-        # Resolve to a room ID, if necessary.
-        if RoomID.is_valid(room_identifier):
-            room_id = room_identifier
-        elif RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
+        room_id, _ = await self.resolve_room_id(room_identifier)
 
         # Which user to grant room admin rights to.
         user_to_add = content.get("user_id", requester.user.to_string())
 
@@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
 
         return 200, {}
 
 
-class ForwardExtremitiesRestServlet(RestServlet):
+class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
     """Allows a server admin to get or clear forward extremities.
 
     Clearing does not require restarting the server.
 
@@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
 
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.store = hs.get_datastore()
 
-    async def resolve_room_id(self, room_identifier: str) -> str:
-        """Resolve to a room ID, if necessary."""
-        if RoomID.is_valid(room_identifier):
-            resolved_room_id = room_identifier
-        elif RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
-            resolved_room_id = room_id.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
-        if not resolved_room_id:
-            raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
-            )
-        return resolved_room_id
-
-    async def on_DELETE(self, request, room_identifier):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
-        room_id = await self.resolve_room_id(room_identifier)
+        room_id, _ = await self.resolve_room_id(room_identifier)
 
         deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
         return 200, {"deleted": deleted_count}
 
-    async def on_GET(self, request, room_identifier):
+    async def on_GET(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
-        room_id = await self.resolve_room_id(room_identifier)
+        room_id, _ = await self.resolve_room_id(room_identifier)
 
         extremities = await self.store.get_forward_extremities_for_room(room_id)
         return 200, {"count": len(extremities), "results": extremities}
 
@@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id, event_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         await assert_user_is_admin(self.auth, requester.user)
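The new `ResolveRoomIdMixin` centralises the ID-or-alias resolution that previously lived in each servlet. A hedged sketch of how a servlet adopts it (the servlet itself is hypothetical):

    class ExampleRoomInfoServlet(ResolveRoomIdMixin, RestServlet):
        PATTERNS = admin_patterns("/example/(?P<room_identifier>[^/]*)$")

        def __init__(self, hs: "HomeServer"):
            super().__init__(hs)  # wires up self.room_member_handler
            self.auth = hs.get_auth()

        async def on_GET(self, request, room_identifier: str):
            await assert_requester_is_admin(self.auth, request)
            # Accepts either "!roomid:server" or "#alias:server"; anything
            # else raises a 400 SynapseError from the mixin.
            room_id, _remote_hosts = await self.resolve_room_id(room_identifier)
            return 200, {"room_id": room_id}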


@@ -16,7 +16,7 @@ import hashlib
 import hmac
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.storage.databases.main.media_repository import MediaSortOrder
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
@@ -46,13 +47,15 @@ logger = logging.getLogger(__name__)
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         target_user = UserID.from_string(user_id)
         await assert_requester_is_admin(self.auth, request)
 
@@ -152,7 +155,7 @@ class UserRestServletV2(RestServlet):
         otherwise an error.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -164,7 +167,9 @@ class UserRestServletV2(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.pusher_pool = hs.get_pusherpool()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
@@ -178,7 +183,9 @@ class UserRestServletV2(RestServlet):
 
         return 200, ret
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -272,6 +279,8 @@ class UserRestServletV2(RestServlet):
             )
 
             user = await self.admin_handler.get_user(target_user)
+            assert user is not None
+
             return 200, user
 
         else:  # create user
@@ -329,9 +338,10 @@ class UserRestServletV2(RestServlet):
                     target_user, requester, body["avatar_url"], True
                 )
 
-            ret = await self.admin_handler.get_user(target_user)
+            user = await self.admin_handler.get_user(target_user)
+            assert user is not None
 
-            return 201, ret
+            return 201, user
 
 
 class UserRegisterServlet(RestServlet):
 
@@ -345,10 +355,10 @@ class UserRegisterServlet(RestServlet):
     PATTERNS = admin_patterns("/register")
     NONCE_TIMEOUT = 60
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}
+        self.nonces = {}  # type: Dict[str, int]
         self.hs = hs
 
     def _clear_old_nonces(self):
 
@@ -361,7 +371,7 @@ class UserRegisterServlet(RestServlet):
             if now - v > self.NONCE_TIMEOUT:
                 del self.nonces[k]
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Generate a new nonce.
         """
 
@@ -371,7 +381,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces[nonce] = int(self.reactor.seconds())
         return 200, {"nonce": nonce}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()
 
         if not self.hs.config.registration_shared_secret:
 
@@ -477,12 +487,14 @@ class WhoisRestServlet(RestServlet):
         client_patterns("/admin" + path_regex, v1=True)
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
         auth_user = requester.user
 
@@ -507,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
 
-    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -549,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         body = parse_json_object_from_request(request)
 
@@ -583,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
-    async def on_POST(self, request, target_user_id):
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         """Post request to allow an administrator to reset the password for a user.
         This needs the user to have administrator access in Synapse.
         """
 
@@ -625,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, target_user_id):
+    async def on_GET(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, Optional[List[JsonDict]]]:
         """Get request to search the user table for specific users according to
         a search term.
         This needs the user to have administrator access in Synapse.
 
@@ -681,12 +699,14 @@ class UserAdminServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
 
@@ -698,7 +718,9 @@ class UserAdminServlet(RestServlet):
 
         return 200, {"admin": is_admin}
 
-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
 
@@ -729,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
@@ -757,7 +781,7 @@ class PushersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
@@ -798,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -832,8 +856,33 @@ class UserMediaRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
+        # If neither `order_by` nor `dir` is set, default to newest media on
+        # top for backwards compatibility.
+        if b"order_by" not in request.args and b"dir" not in request.args:
+            order_by = MediaSortOrder.CREATED_TS.value
+            direction = "b"
+        else:
+            order_by = parse_string(
+                request,
+                "order_by",
+                default=MediaSortOrder.CREATED_TS.value,
+                allowed_values=(
+                    MediaSortOrder.MEDIA_ID.value,
+                    MediaSortOrder.UPLOAD_NAME.value,
+                    MediaSortOrder.CREATED_TS.value,
+                    MediaSortOrder.LAST_ACCESS_TS.value,
+                    MediaSortOrder.MEDIA_LENGTH.value,
+                    MediaSortOrder.MEDIA_TYPE.value,
+                    MediaSortOrder.QUARANTINED_BY.value,
+                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+                ),
+            )
+            direction = parse_string(
+                request, "dir", default="f", allowed_values=("f", "b")
+            )
+
         media, total = await self.store.get_local_media_by_user_paginate(
-            start, limit, user_id
+            start, limit, user_id, order_by, direction
         )
 
         ret = {"media": media, "total": total}
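With these parameters in place, the media list for a user can be sorted server-side. A hedged request sketch using the third-party `requests` package (host and token are placeholders):

    import urllib.parse

    import requests

    # User IDs contain characters that must be percent-encoded in the path.
    user_id = urllib.parse.quote("@alice:example.org", safe="")
    resp = requests.get(
        f"https://homeserver.example/_synapse/admin/v1/users/{user_id}/media",
        params={"from": 0, "limit": 10, "order_by": "created_ts", "dir": "b"},
        headers={"Authorization": "Bearer <admin_access_token>"},
    )
    # Omitting both order_by and dir keeps the old newest-first ordering,
    # per the backwards-compatibility branch above.
    print(resp.json()["total"])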
@@ -865,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
 
@@ -917,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):


@@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
 from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._public_baseurl = hs.config.public_baseurl
 
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
@@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, idp_id: Optional[str] = None
     ) -> None:
+        if not self._public_baseurl:
+            raise SynapseError(400, "SSO requires a valid public_baseurl")
+
+        # if this isn't the expected hostname, redirect to the right one, so that we
+        # get our cookies back.
+        requested_uri = get_request_uri(request)
+        baseurl_bytes = self._public_baseurl.encode("utf-8")
+        if not requested_uri.startswith(baseurl_bytes):
+            # swap out the incorrect base URL for the right one.
+            #
+            # The idea here is to redirect from
+            #    https://foo.bar/whatever/_matrix/...
+            # to
+            #    https://public.baseurl/_matrix/...
+            #
+            i = requested_uri.index(b"/_matrix")
+            new_uri = baseurl_bytes[:-1] + requested_uri[i:]
+            logger.info(
+                "Requested URI %s is not canonical: redirecting to %s",
+                requested_uri.decode("utf-8", errors="replace"),
+                new_uri.decode("utf-8", errors="replace"),
+            )
+            request.redirect(new_uri)
+            finish_request(request)
+            return
+
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
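The redirect above only rewrites the scheme, host, and path prefix, keeping everything from "/_matrix" onwards. A standalone sketch of that rewrite (function name hypothetical; assumes `public_baseurl` is normalised to end in "/"):

    from typing import Optional

    def canonical_redirect(requested_uri: bytes, public_baseurl: str) -> Optional[bytes]:
        """Return the URI to redirect to, or None if already canonical."""
        baseurl_bytes = public_baseurl.encode("utf-8")
        if requested_uri.startswith(baseurl_bytes):
            return None
        # Graft the path from "/_matrix" onwards onto the configured base URL
        # (which ends in "/", hence the [:-1]).
        i = requested_uri.index(b"/_matrix")
        return baseurl_bytes[:-1] + requested_uri[i:]

    print(canonical_redirect(
        b"https://foo.bar/whatever/_matrix/client/r0/login/sso/redirect",
        "https://public.baseurl/",
    ))  # b'https://public.baseurl/_matrix/client/r0/login/sso/redirect'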


@@ -18,7 +18,7 @@ import logging
 from functools import wraps
 from typing import TYPE_CHECKING, Optional, Tuple
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.constants import (
     MAX_GROUP_CATEGORYID_LENGTH,


@@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         content = parse_json_object_from_request(request)
         assert_params_in_dict(content, ("messages",))
 
-        sender_user_id = requester.user.to_string()
-
         await self.device_message_handler.send_device_message(
-            sender_user_id, message_type, content["messages"]
+            requester, message_type, content["messages"]
         )
 
         response = (200, {})  # type: Tuple[int, dict]


@@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [
 
 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
+        # The type on postpath seems incorrect in Twisted 21.2.0.
+        postpath = request.postpath  # type: List[bytes]  # type: ignore
+        assert postpath
+
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
-        server_name, media_id = request.postpath[:2]
-
-        if isinstance(server_name, bytes):
-            server_name = server_name.decode("utf-8")
-            media_id = media_id.decode("utf8")
+        server_name_bytes, media_id_bytes = postpath[:2]
+        server_name = server_name_bytes.decode("utf-8")
+        media_id = media_id_bytes.decode("utf8")
 
         file_name = None
-        if len(request.postpath) > 2:
+        if len(postpath) > 2:
             try:
-                file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
+                file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
             except UnicodeDecodeError:
                 pass
         return server_name, media_id, file_name
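For a request such as GET /_matrix/media/r0/download/example.org/abc123/test.png, Twisted's postpath relative to the media resource would be [b"example.org", b"abc123", b"test.png"]. A standalone sketch of the decoding above (function name hypothetical):

    import urllib.parse
    from typing import List, Optional, Tuple

    def parse_postpath(postpath: List[bytes]) -> Tuple[str, str, Optional[str]]:
        server_name = postpath[0].decode("utf-8")
        media_id = postpath[1].decode("utf-8")
        file_name = None
        # A trailing segment such as "test.png" is an optional client-supplied
        # file name; undecodable names are simply ignored.
        if len(postpath) > 2:
            try:
                file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
            except UnicodeDecodeError:
                pass
        return server_name, media_id, file_name

    assert parse_postpath([b"example.org", b"abc123", b"test.png"]) == (
        "example.org",
        "abc123",
        "test.png",
    )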


@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.http.server import DirectServeJsonResource, respond_with_json


@@ -16,7 +16,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_boolean


@@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
-from twisted.web.http import Request
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.api.errors import (
     FederationDeniedError,
@@ -509,7 +509,7 @@ class MediaRepository:
         t_height: int,
         t_method: str,
         t_type: str,
-        url_cache: str,
+        url_cache: Optional[str],
     ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)


@@ -244,7 +244,7 @@ class MediaStorage:
                 await consumer.wait()
             return local_path
 
-        raise Exception("file could not be found")
+        raise NotFoundError()
 
     def _file_info_to_path(self, file_info: FileInfo) -> str:
         """Converts file_info into a relative path.


@@ -29,7 +29,7 @@ from urllib import parse as urlparse
 import attr
 
 from twisted.internet.error import DNSLookupError
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -137,9 +137,8 @@ class PreviewUrlResource(DirectServeJsonResource):
             treq_args={"browser_like_redirects": True},
             ip_whitelist=hs.config.url_preview_ip_range_whitelist,
             ip_blacklist=hs.config.url_preview_ip_range_blacklist,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
-            user_agent=f"{hs.version_string} UrlPreviewBot"
+            user_agent=f"{hs.version_string} UrlPreviewBot",
+            use_proxy=True,
         )
         self.media_repo = media_repo
         self.primary_base_path = media_repo.primary_base_path


@@ -18,7 +18,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
@@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource):
             m_type,
             thumbnail_infos,
             media_id,
+            media_id,
             url_cache=media_info["url_cache"],
             server_name=None,
         )
@@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource):
             method,
             m_type,
             thumbnail_infos,
+            media_id,
             media_info["filesystem_id"],
             url_cache=None,
             server_name=server_name,
@@ -282,6 +284,7 @@ class ThumbnailResource(DirectServeJsonResource):
         desired_method: str,
         desired_type: str,
         thumbnail_infos: List[Dict[str, Any]],
+        media_id: str,
         file_id: str,
         url_cache: Optional[str] = None,
         server_name: Optional[str] = None,
@@ -317,8 +320,59 @@ class ThumbnailResource(DirectServeJsonResource):
                 return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(
-                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
-            )
+            if responder:
+                await respond_with_responder(
+                    request,
+                    responder,
+                    file_info.thumbnail_type,
+                    file_info.thumbnail_length,
+                )
+                return
+
+            # If we can't find the thumbnail we regenerate it. This can happen
+            # if e.g. we've deleted the thumbnails but still have the original
+            # image somewhere.
+            #
+            # Since we have an entry for the thumbnail in the DB we a) know we
+            # have successfully generated the thumbnail in the past (so we
+            # don't need to worry about repeatedly failing to generate
+            # thumbnails), and b) have already calculated the appropriate
+            # width/height/method, so we can just call the "generate exact"
+            # methods.
+
+            # First let's check that we do actually have the original image
+            # still. This will throw a 404 if we don't.
+            # TODO: We should refetch the thumbnails for remote media.
+            await self.media_storage.ensure_media_is_in_local_cache(
+                FileInfo(server_name, file_id, url_cache=url_cache)
+            )
+
+            if server_name:
+                await self.media_repo.generate_remote_exact_thumbnail(
+                    server_name,
+                    file_id=file_id,
+                    media_id=media_id,
+                    t_width=file_info.thumbnail_width,
+                    t_height=file_info.thumbnail_height,
+                    t_method=file_info.thumbnail_method,
+                    t_type=file_info.thumbnail_type,
+                )
+            else:
+                await self.media_repo.generate_local_exact_thumbnail(
+                    media_id=media_id,
+                    t_width=file_info.thumbnail_width,
+                    t_height=file_info.thumbnail_height,
+                    t_method=file_info.thumbnail_method,
+                    t_type=file_info.thumbnail_type,
+                    url_cache=url_cache,
+                )
+
+            responder = await self.media_storage.fetch_media(file_info)
+            await respond_with_responder(
+                request,
+                responder,
+                file_info.thumbnail_type,
+                file_info.thumbnail_length,
+            )
         else:
             logger.info("Failed to find any generated thumbnails")


@@ -15,9 +15,9 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
@@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
 
         headers = request.requestHeaders
         if headers.hasHeader(b"Content-Type"):
-            media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
+            content_type_headers = headers.getRawHeaders(b"Content-Type")
+            assert content_type_headers  # for mypy
+            media_type = content_type_headers[0].decode("ascii")
         else:
             raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
 
@@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource):
         # TODO(markjh): parse content-disposition
 
         try:
+            content = request.content  # type: IO  # type: ignore
             content_uri = await self.media_repo.create_content(
-                media_type, upload_name, request.content, content_length, requester.user
+                media_type, upload_name, content, content_length, requester.user
             )
         except SpamMediaException:
             # For uploading of media we want to respond with a 400, instead of


@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request


@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import ThreepidValidationError
 from synapse.config.emailconfig import ThreepidBehaviour


@@ -16,8 +16,8 @@
 import logging
 from typing import TYPE_CHECKING, List
 
-from twisted.web.http import Request
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request


@@ -16,7 +16,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from twisted.web.http import Request
+from twisted.web.server import Request
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request


@@ -24,7 +24,6 @@
 import abc
 import functools
 import logging
-import os
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -39,6 +38,7 @@ from typing import (
 
 import twisted.internet.base
 import twisted.internet.tcp
+from twisted.internet import defer
 from twisted.mail.smtp import sendmail
 from twisted.web.iweb import IPolicyForHTTPS
 
@@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         self.start_time = None  # type: Optional[int]
 
         self._instance_id = random_string(5)
-        self._instance_name = config.worker_name or "master"
+        self._instance_name = config.worker.instance_name
 
         self.version_string = version_string
 
@@ -370,11 +370,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         An HTTP client that uses configured HTTP(S) proxies.
         """
-        return SimpleHttpClient(
-            self,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
-        )
+        return SimpleHttpClient(self, use_proxy=True)
 
     @cache_in_self
     def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@@ -386,8 +382,7 @@ class HomeServer(metaclass=abc.ABCMeta):
             self,
             ip_whitelist=self.config.ip_range_whitelist,
             ip_blacklist=self.config.ip_range_blacklist,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+            use_proxy=True,
         )
 
     @cache_in_self
@@ -409,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomShutdownHandler(self)
 
     @cache_in_self
-    def get_sendmail(self) -> sendmail:
+    def get_sendmail(self) -> Callable[..., defer.Deferred]:
         return sendmail
 
     @cache_in_self
@@ -758,12 +753,6 @@ class HomeServer(metaclass=abc.ABCMeta):
             reconnect=True,
         )
 
-    async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
-
     def should_send_federation(self) -> bool:
         "Should this server be sending federation traffic directly?"
-        return self.config.send_federation and (
-            not self.config.worker_app
-            or self.config.worker_app == "synapse.app.federation_sender"
-        )
+        return self.config.send_federation


@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
     _TXN_ID = 0
 
     def __init__(
-        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+        self,
+        hs,
+        database_config: DatabaseConnectionConfig,
+        engine: BaseDatabaseEngine,
     ):
         self.hs = hs
         self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
-        # We define this sequence here so that it can be referenced from both
-        # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self.event_chain_id_gen = build_sequence_generator(
-            engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
     def is_running(self) -> bool:
         """Is the database pool currently running"""
         return self._db_pool.running


@@ -79,7 +79,7 @@ class Databases:
             # If we're on a process that can persist events also
             # instantiate a `PersistEventsStore`
             if hs.get_instance_name() in hs.config.worker.writers.events:
-                persist_events = PersistEventsStore(hs, database, main)
+                persist_events = PersistEventsStore(hs, database, main, db_conn)
 
         if "state" in database_config.databases:
             logger.info(


@@ -16,7 +16,7 @@
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def get_users(self) -> List[Dict[str, Any]]:
+    async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.
 
         Returns:
@@ -292,7 +292,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
-    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
         """Function to search users list for one or more users with
         the matched term.

View file

@@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -90,7 +92,11 @@ class PersistEventsStore:
     """

     def __init__(
-        self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+        self,
+        hs: "HomeServer",
+        db: DatabasePool,
+        main_data_store: "DataStore",
+        db_conn: Connection,
     ):
         self.hs = hs
         self.db_pool = db
@@ -474,6 +480,7 @@
         self._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -484,6 +491,7 @@
         cls,
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -630,6 +638,7 @@
         new_chain_tuples = cls._allocate_chain_ids(
             txn,
             db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -768,6 +777,7 @@
     def _allocate_chain_ids(
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -880,7 +890,7 @@
             chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

         # Generate new chain IDs for all unallocated chain IDs.
-        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
             txn, len(unallocated_chain_ids)
         )
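Threading `event_chain_id_gen` through as an explicit argument, instead of reaching through `db_pool`, keeps these class and static methods free of hidden state and lets callers (or tests) substitute a different generator. A rough sketch of the interface the calls above rely on; `SequenceGeneratorLike` is a hypothetical stand-in, and only the `get_next_mult_txn` method is assumed from the diff:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SequenceGeneratorLike(ABC):
    """Hypothetical stand-in for synapse.storage.util.sequence.SequenceGenerator."""

    @abstractmethod
    def get_next_mult_txn(self, txn: Any, n: int) -> List[int]:
        """Allocate `n` fresh IDs within the given transaction."""


def allocate_chain_ids_sketch(
    txn: Any, id_gen: SequenceGeneratorLike, unallocated: List[str]
) -> Dict[str, int]:
    # One round trip allocates IDs for every chain that needs one.
    new_ids = id_gen.get_next_mult_txn(txn, len(unallocated))
    return dict(zip(unallocated, new_ids))
```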

View file

@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 )

                 if not has_event_auth:
-                    for auth_id in event.auth_event_ids():
+                    # Old, dodgy events may have duplicate auth events, which we
+                    # need to deduplicate as we have a unique constraint.
+                    for auth_id in set(event.auth_event_ids()):
                         auth_events.append(
                             {
                                 "room_id": event.room_id,
@@ -917,6 +919,7 @@
                 PersistEventsStore._add_chain_cover_index(
                     txn,
                     self.db_pool,
+                    self.event_chain_id_gen,
                     event_to_room_id,
                     event_to_types,
                     event_to_auth_chain,

View file

@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0

+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_chain_id_txn,
+            "event_auth_chain_id",
+            table="event_auth_chains",
+            id_column="chain_id",
+        )
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
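The extra `db_conn`, `table` and `id_column` arguments suggest the sequence builder can now check itself against the backing table at startup, presumably to catch sequences that have fallen behind after a database port (see the v1.26.0 `synapse_port_db` fix in the changelog). A simplified sketch of the assumed fallback behaviour on engines without native sequences; none of these names are Synapse's real internals:

```python
import itertools
from typing import Any, Callable, List


class EmulatedSequenceSketch:
    """Hypothetical: emulate a database sequence by seeding from a query."""

    def __init__(self, db_conn: Any, get_first_txn: Callable[[Any], int]):
        # Seed once at startup from the current maximum (e.g. max(chain_id)),
        # then hand out strictly increasing values from memory.
        cur = db_conn.cursor()
        self._counter = itertools.count(get_first_txn(cur) + 1)

    def get_next_mult_txn(self, txn: Any, n: int) -> List[int]:
        return [next(self._counter) for _ in range(n)]
```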

View file

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from enum import Enum
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
 )


+class MediaSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning media with
+    get_local_media_by_user_paginate
+    """
+
+    MEDIA_ID = "media_id"
+    UPLOAD_NAME = "upload_name"
+    CREATED_TS = "created_ts"
+    LAST_ACCESS_TS = "last_access_ts"
+    MEDIA_LENGTH = "media_length"
+    MEDIA_TYPE = "media_type"
+    QUARANTINED_BY = "quarantined_by"
+    SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )

     async def get_local_media_by_user_paginate(
-        self, start: int, limit: int, user_id: str
+        self,
+        start: int,
+        limit: int,
+        user_id: str,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
+        direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
@@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             start: offset in the list
             limit: maximum amount of media_ids to retrieve
             user_id: fully-qualified user id
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending

         Returns:
             A paginated list of all metadata of user's media,
             plus the total count of all the user's media
@@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

         def get_local_media_by_user_paginate_txn(txn):

+            # Set ordering
+            order_by_column = MediaSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             args = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     "safe_from_quarantine"
                 FROM local_media_repository
                 WHERE user_id = ?
-                ORDER BY created_ts DESC, media_id DESC
+                ORDER BY {order_by_column} {order}, media_id ASC
                 LIMIT ? OFFSET ?
-            """
+            """.format(
+                order_by_column=order_by_column,
+                order=order,
+            )

             args += [limit, start]
             txn.execute(sql, args)
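One detail worth noting: `order_by_column` is interpolated into the SQL with `.format()` rather than bound as a `?` parameter, which is only safe because `MediaSortOrder(order_by)` raises `ValueError` for anything that is not a known column name. A standalone illustration of that guard (using a trimmed copy of the enum, for illustration only):

```python
from enum import Enum


class MediaSortOrderSketch(Enum):
    # Subset of the enum above.
    MEDIA_ID = "media_id"
    CREATED_TS = "created_ts"


def order_clause(order_by: str, direction: str) -> str:
    # Round-tripping through the enum whitelists the column name, so the
    # string interpolation below can never see attacker-controlled text.
    column = MediaSortOrderSketch(order_by).value
    order = "DESC" if direction == "b" else "ASC"
    return "ORDER BY {} {}, media_id ASC".format(column, order)


print(order_clause("created_ts", "b"))  # ORDER BY created_ts DESC, media_id ASC
# order_clause("1; DROP TABLE users --", "f") raises ValueError, not SQL injection
```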
@@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        await self.db_pool.simple_insert(
-            "local_media_repository_thumbnails",
-            {
+        await self.db_pool.simple_upsert(
+            table="local_media_repository_thumbnails",
+            keyvalues={
                 "media_id": media_id,
                 "thumbnail_width": thumbnail_width,
                 "thumbnail_height": thumbnail_height,
                 "thumbnail_method": thumbnail_method,
                 "thumbnail_type": thumbnail_type,
-                "thumbnail_length": thumbnail_length,
             },
+            values={"thumbnail_length": thumbnail_length},
             desc="store_local_thumbnail",
         )
@@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        await self.db_pool.simple_insert(
-            "remote_media_cache_thumbnails",
-            {
+        await self.db_pool.simple_upsert(
+            table="remote_media_cache_thumbnails",
+            keyvalues={
                 "media_origin": origin,
                 "media_id": media_id,
                 "thumbnail_width": thumbnail_width,
                 "thumbnail_height": thumbnail_height,
                 "thumbnail_method": thumbnail_method,
                 "thumbnail_type": thumbnail_type,
-                "thumbnail_length": thumbnail_length,
-                "filesystem_id": filesystem_id,
             },
+            values={"thumbnail_length": thumbnail_length},
+            insertion_values={"filesystem_id": filesystem_id},
             desc="store_remote_media_thumbnail",
         )
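Switching these two methods from `simple_insert` to `simple_upsert` is what makes thumbnail regeneration (#9438 in the changelog) idempotent: re-running it for an existing `(media_id, width, height, method, type)` row updates `thumbnail_length` instead of violating the unique constraint. In raw SQL the equivalent pattern looks roughly like this; a sketch with a toy schema, not Synapse's actual helper:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE thumbnails (
        media_id TEXT, width INTEGER, height INTEGER,
        method TEXT, type TEXT, length INTEGER,
        UNIQUE (media_id, width, height, method, type)
    )
    """
)


def store_thumbnail(media_id, width, height, method, type_, length):
    # Insert, or on conflict only update the mutable column -- the same
    # keyvalues/values split that simple_upsert expresses.
    conn.execute(
        """
        INSERT INTO thumbnails VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (media_id, width, height, method, type)
        DO UPDATE SET length = excluded.length
        """,
        (media_id, width, height, method, type_, length),
    )


store_thumbnail("m1", 32, 32, "crop", "image/png", 100)
store_thumbnail("m1", 32, 32, "crop", "image/png", 128)  # no IntegrityError
```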

View file

@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
-        """Deletes room history before a certain point
+        """Deletes room history before a certain point.
+
+        Note that only a single purge can occur at once; this is guaranteed at
+        a higher level (in the PaginationHandler).

         Args:
             room_id:
@@ -52,7 +55,9 @@
             delete_local_events,
         )

-    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+    def _purge_history_txn(
+        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+    ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -103,7 +108,7 @@
         if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
+            # having any backwards extremities)
             raise SynapseError(
                 400, "topological_ordering is greater than forward extremeties"
             )
@@ -154,7 +159,7 @@
         logger.info("[purge] Finding new backward extremities")

-        # We calculate the new entries for the backward extremeties by finding
+        # We calculate the new entries for the backward extremities by finding
         # events to be purged that are pointed to by events we're not going to
         # purge.
         txn.execute(
@@ -296,7 +301,7 @@
             "purge_room", self._purge_room_txn, room_id
         )

-    def _purge_room_txn(self, txn, room_id):
+    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
@@ -310,6 +315,31 @@
         state_groups = [row[0] for row in txn]

+        # Get all the auth chains that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT chain_id, sequence_number FROM events
+            LEFT JOIN event_auth_chains USING (event_id)
+            WHERE room_id = ?
+            """,
+            (room_id,),
+        )
+        referenced_chain_id_tuples = list(txn)
+
+        logger.info("[purge] removing events from event_auth_chain_links")
+        txn.executemany(
+            """
+            DELETE FROM event_auth_chain_links WHERE
+            (origin_chain_id = ? AND origin_sequence_number = ?) OR
+            (target_chain_id = ? AND target_sequence_number = ?)
+            """,
+            (
+                (chain_id, seq_num, chain_id, seq_num)
+                for (chain_id, seq_num) in referenced_chain_id_tuples
+            ),
+        )
+
         # Now we delete tables which lack an index on room_id but have one on event_id
         for table in (
             "event_auth",
@@ -319,6 +349,8 @@
             "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
+            "event_auth_chains",
+            "event_auth_chain_to_calculate",
             "redactions",
             "rejections",
             "state_events",

View file

@@ -39,6 +39,16 @@ class PusherWorkerStore(SQLBaseStore):
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )

+        self.db_pool.updates.register_background_update_handler(
+            "remove_deactivated_pushers",
+            self._remove_deactivated_pushers,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "remove_stale_pushers",
+            self._remove_stale_pushers,
+        )
+
     def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
@@ -284,6 +294,101 @@
             lock=False,
         )

+    async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
+        """A background update that deletes all pushers for deactivated users.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+        """
+
+        last_user = progress.get("last_user", "")
+
+        def _delete_pushers(txn) -> int:
+
+            sql = """
+                SELECT name FROM users
+                WHERE deactivated = ? and name > ?
+                ORDER BY name ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (1, last_user, batch_size))
+            users = [row[0] for row in txn]
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="pushers",
+                column="user_name",
+                iterable=users,
+                keyvalues={},
+            )
+
+            if users:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_deactivated_pushers", {"last_user": users[-1]}
+                )
+
+            return len(users)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deactivated_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "remove_deactivated_pushers"
+            )
+
+        return number_deleted
+
+    async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
+        """A background update that deletes all pushers for logged out devices.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+        """
+
+        last_pusher = progress.get("last_pusher", 0)
+
+        def _delete_pushers(txn) -> int:
+
+            sql = """
+                SELECT p.id, access_token FROM pushers AS p
+                LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
+                WHERE p.id > ?
+                ORDER BY p.id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_pusher, batch_size))
+            pushers = [(row[0], row[1]) for row in txn]
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="pushers",
+                column="id",
+                iterable=(pusher_id for pusher_id, token in pushers if token is None),
+                keyvalues={},
+            )
+
+            if pushers:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
+                )
+
+            return len(pushers)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_stale_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update("remove_stale_pushers")
+
+        return number_deleted
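Both handlers follow the same background-update contract: read a progress marker, handle one batch inside a transaction, persist the new marker, and end the update once a short batch indicates the table is exhausted. A stripped-down sketch of that loop, with hypothetical callables standing in for the `db_pool.updates` machinery:

```python
from typing import Awaitable, Callable, List


async def run_batched_update_sketch(
    progress: dict,
    batch_size: int,
    fetch_batch: Callable[[str, int], Awaitable[List[str]]],
    save_progress: Callable[[dict], Awaitable[None]],
    end_update: Callable[[], Awaitable[None]],
) -> int:
    last = progress.get("last_user", "")
    batch = await fetch_batch(last, batch_size)  # e.g. next users to clean up
    if batch:
        # Persist the high-water mark so a restart resumes rather than restarts.
        await save_progress({"last_user": batch[-1]})
    if len(batch) < batch_size:
        await end_update()  # a short batch means nothing is left to do
    return len(batch)
```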

 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self) -> int:
@@ -373,3 +478,46 @@ class PusherStore(PusherWorkerStore):
         await self.db_pool.runInteraction(
             "delete_pusher", delete_pusher_txn, stream_id
         )
+
+    async def delete_all_pushers_for_user(self, user_id: str) -> None:
+        """Delete all pushers associated with an account."""
+
+        # We want to generate a row in `deleted_pushers` for each pusher we're
+        # deleting, so we fetch the list now so we can generate the appropriate
+        # number of stream IDs.
+        #
+        # Note: technically there could be a race here between adding/deleting
+        # pushers, but a) the worst case is that we don't stop a pusher until
+        # the next restart and b) this is only called when we're deactivating
+        # an account.
+        pushers = list(await self.get_pushers_by_user_id(user_id))
+
+        def delete_pushers_txn(txn, stream_ids):
+            self._invalidate_cache_and_stream(  # type: ignore
+                txn, self.get_if_user_has_pusher, (user_id,)
+            )
+
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="pushers",
+                keyvalues={"user_name": user_id},
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="deleted_pushers",
+                values=[
+                    {
+                        "stream_id": stream_id,
+                        "app_id": pusher.app_id,
+                        "pushkey": pusher.pushkey,
+                        "user_id": user_id,
+                    }
+                    for stream_id, pusher in zip(stream_ids, pushers)
+                ],
+            )
+
+        async with self._pushers_id_gen.get_next_mult(len(pushers)) as stream_ids:
+            await self.db_pool.runInteraction(
+                "delete_all_pushers_for_user", delete_pushers_txn, stream_ids
+            )
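The shape of `delete_all_pushers_for_user` is dictated by stream-ID accounting: the pushers are listed up front so the code knows how many IDs to reserve, and `get_next_mult` is used as an async context manager so the IDs are only treated as persisted once the transaction inside the block succeeds. A toy version of that pattern; the real `_pushers_id_gen` is considerably more involved:

```python
import asyncio
from contextlib import asynccontextmanager


class ToyIdGen:
    """Hypothetical stand-in for Synapse's stream ID generators."""

    def __init__(self) -> None:
        self._next = 1

    @asynccontextmanager
    async def get_next_mult(self, n: int):
        ids = list(range(self._next, self._next + n))
        self._next += n
        yield ids  # the real generator marks these persisted on clean exit


async def main() -> None:
    id_gen = ToyIdGen()
    pushers = ["pusher-a", "pusher-b"]
    async with id_gen.get_next_mult(len(pushers)) as stream_ids:
        for stream_id, pusher in zip(stream_ids, pushers):
            print(stream_id, pusher)  # one deleted_pushers row per pusher


asyncio.run(main())
```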

View file

@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:

 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.config = hs.config
@@ -79,9 +84,12 @@
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )

         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@

 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self._clock = hs.get_clock()

View file

@@ -1,4 +1,4 @@
-/* Copyright 2020 The Matrix.org Foundation C.I.C
+/* Copyright 2021 The Matrix.org Foundation C.I.C
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,5 +13,8 @@
  * limitations under the License.
  */

+-- We may not have deleted all pushers for deactivated accounts, so we set up a
+-- background job to delete them.
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (5828, 'rejected_events_metadata', '{}');
+  (5908, 'remove_deactivated_pushers', '{}');

View file

@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5908, 'remove_stale_pushers', '{}');
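These delta files only insert rows into `background_updates`; they take effect because `PusherWorkerStore.__init__` (earlier in this diff) registers handlers under the same names. A minimal sketch of that name-based pairing, with a toy registry standing in for `db_pool.updates`:

```python
from typing import Awaitable, Callable, Dict

HANDLERS: Dict[str, Callable[[dict, int], Awaitable[int]]] = {}


def register_background_update_handler(name, handler) -> None:
    # Mirrors the registration calls in PusherWorkerStore.__init__.
    HANDLERS[name] = handler


async def run_pending_updates(rows) -> None:
    # `rows` plays the role of SELECTing from the background_updates table;
    # each row's update_name must match a registered handler exactly.
    for update_name, progress in rows:
        while await HANDLERS[update_name](progress, 100) > 0:
            pass  # keep going until the handler reports an empty batch
```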
