Merge remote-tracking branch 'upstream/release-v1.50'

This commit is contained in:
Tulir Asokan 2022-01-07 14:21:32 +02:00
commit e9caf56ca0
205 changed files with 4905 additions and 2749 deletions

View File

@ -76,7 +76,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
database: ["sqlite"]
toxenv: ["py"]
include:
@ -85,9 +85,9 @@ jobs:
toxenv: "py-noextras"
# Oldest Python with PostgreSQL
- python-version: "3.6"
- python-version: "3.7"
database: "postgres"
postgres-version: "9.6"
postgres-version: "10"
toxenv: "py"
# Newest Python with newest PostgreSQL
@ -167,7 +167,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["pypy-3.6"]
python-version: ["pypy-3.7"]
steps:
- uses: actions/checkout@v2
@ -291,8 +291,8 @@ jobs:
strategy:
matrix:
include:
- python-version: "3.6"
postgres-version: "9.6"
- python-version: "3.7"
postgres-version: "10"
- python-version: "3.10"
postgres-version: "14"

View File

@ -1,3 +1,92 @@
Synapse 1.50.0rc1 (2022-01-05)
==============================
Please note that we now only support Python 3.7+ and PostgreSQL 10+ (if applicable), because Python 3.6 and PostgreSQL 9.6 have reached end-of-life.
Features
--------
- Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). ([\#11378](https://github.com/matrix-org/synapse/issues/11378))
- Add experimental support for part of [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): allowing application services to masquerade as specific devices. ([\#11538](https://github.com/matrix-org/synapse/issues/11538))
- Add admin API to get users' account data. ([\#11664](https://github.com/matrix-org/synapse/issues/11664))
- Include the room topic in the stripped state included with invites and knocking. ([\#11666](https://github.com/matrix-org/synapse/issues/11666))
- Send and handle cross-signing messages using the stable prefix. ([\#10520](https://github.com/matrix-org/synapse/issues/10520))
- Support unprefixed versions of fallback key property names. ([\#11541](https://github.com/matrix-org/synapse/issues/11541))
Bugfixes
--------
- Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. ([\#11516](https://github.com/matrix-org/synapse/issues/11516))
- Fix a long-standing bug which could cause `AssertionError`s to be written to the log when Synapse was restarted after purging events from the database. ([\#11536](https://github.com/matrix-org/synapse/issues/11536), [\#11642](https://github.com/matrix-org/synapse/issues/11642))
- Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. ([\#11547](https://github.com/matrix-org/synapse/issues/11547))
- Fix a long-standing bug where responses included bundled aggregations when they should not, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11592](https://github.com/matrix-org/synapse/issues/11592), [\#11623](https://github.com/matrix-org/synapse/issues/11623))
- Fix a long-standing bug that some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. ([\#11602](https://github.com/matrix-org/synapse/issues/11602))
- Fix a bug introduced in Synapse 1.19.3 which could sometimes cause `AssertionError`s when backfilling rooms over federation. ([\#11632](https://github.com/matrix-org/synapse/issues/11632))
Improved Documentation
----------------------
- Update Synapse install command for FreeBSD as the package is now prefixed with `py38`. Contributed by @itchychips. ([\#11267](https://github.com/matrix-org/synapse/issues/11267))
- Document the usage of refresh tokens. ([\#11427](https://github.com/matrix-org/synapse/issues/11427))
- Add details for how to configure a TURN server when behind a NAT. Contibuted by @AndrewFerr. ([\#11553](https://github.com/matrix-org/synapse/issues/11553))
- Add references for using Postgres to the Docker documentation. ([\#11640](https://github.com/matrix-org/synapse/issues/11640))
- Fix the documentation link in newly-generated configuration files. ([\#11678](https://github.com/matrix-org/synapse/issues/11678))
- Correct the documentation for `nginx` to use a case-sensitive url pattern. Fixes an error introduced in v1.21.0. ([\#11680](https://github.com/matrix-org/synapse/issues/11680))
- Clarify SSO mapping provider documentation by writing `def` or `async def` before the names of methods, as appropriate. ([\#11681](https://github.com/matrix-org/synapse/issues/11681))
Deprecations and Removals
-------------------------
- Replace `mock` package by its standard library version. ([\#11588](https://github.com/matrix-org/synapse/issues/11588))
Internal Changes
----------------
- Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#11243](https://github.com/matrix-org/synapse/issues/11243))
- A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. ([\#11331](https://github.com/matrix-org/synapse/issues/11331))
- Add type hints to `synapse.appservice`. ([\#11360](https://github.com/matrix-org/synapse/issues/11360))
- Add missing type hints to `synapse.config` module. ([\#11480](https://github.com/matrix-org/synapse/issues/11480))
- Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. ([\#11487](https://github.com/matrix-org/synapse/issues/11487))
- Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. ([\#11503](https://github.com/matrix-org/synapse/issues/11503))
- Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. ([\#11505](https://github.com/matrix-org/synapse/issues/11505), [\#11687](https://github.com/matrix-org/synapse/issues/11687))
- Use `HTTPStatus` constants in place of literals in `tests.rest.client.test_auth`. ([\#11520](https://github.com/matrix-org/synapse/issues/11520))
- Add a receipt types constant for `m.read`. ([\#11531](https://github.com/matrix-org/synapse/issues/11531))
- Clean up `synapse.rest.admin`. ([\#11535](https://github.com/matrix-org/synapse/issues/11535))
- Add missing `errcode` to `parse_string` and `parse_boolean`. ([\#11542](https://github.com/matrix-org/synapse/issues/11542))
- Use `HTTPStatus` constants in place of literals in `synapse.http`. ([\#11543](https://github.com/matrix-org/synapse/issues/11543))
- Add missing type hints to storage classes. ([\#11546](https://github.com/matrix-org/synapse/issues/11546), [\#11549](https://github.com/matrix-org/synapse/issues/11549), [\#11551](https://github.com/matrix-org/synapse/issues/11551), [\#11555](https://github.com/matrix-org/synapse/issues/11555), [\#11575](https://github.com/matrix-org/synapse/issues/11575), [\#11589](https://github.com/matrix-org/synapse/issues/11589), [\#11594](https://github.com/matrix-org/synapse/issues/11594), [\#11652](https://github.com/matrix-org/synapse/issues/11652), [\#11653](https://github.com/matrix-org/synapse/issues/11653), [\#11654](https://github.com/matrix-org/synapse/issues/11654), [\#11657](https://github.com/matrix-org/synapse/issues/11657))
- Fix an inaccurate and misleading comment in the `/sync` code. ([\#11550](https://github.com/matrix-org/synapse/issues/11550))
- Add missing type hints to `synapse.logging.context`. ([\#11556](https://github.com/matrix-org/synapse/issues/11556))
- Stop populating unused database column `state_events.prev_state`. ([\#11558](https://github.com/matrix-org/synapse/issues/11558))
- Minor efficiency improvements in event persistence. ([\#11560](https://github.com/matrix-org/synapse/issues/11560))
- Add some safety checks that storage functions are used correctly. ([\#11564](https://github.com/matrix-org/synapse/issues/11564), [\#11580](https://github.com/matrix-org/synapse/issues/11580))
- Make `get_device` return `None` if the device doesn't exist rather than raising an exception. ([\#11565](https://github.com/matrix-org/synapse/issues/11565))
- Split the HTML parsing code from the URL preview resource code. ([\#11566](https://github.com/matrix-org/synapse/issues/11566))
- Remove redundant `COALESCE()`s around `COUNT()`s in database queries. ([\#11570](https://github.com/matrix-org/synapse/issues/11570))
- Add missing type hints to `synapse.http`. ([\#11571](https://github.com/matrix-org/synapse/issues/11571))
- Add [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) and [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) to `/versions` -> `unstable_features` to detect server support. ([\#11582](https://github.com/matrix-org/synapse/issues/11582))
- Add type hints to `synapse/tests/rest/admin`. ([\#11590](https://github.com/matrix-org/synapse/issues/11590))
- Drop end-of-life Python 3.6 and Postgres 9.6 from CI. ([\#11595](https://github.com/matrix-org/synapse/issues/11595))
- Update black version and run it on all the files. ([\#11596](https://github.com/matrix-org/synapse/issues/11596))
- Add opentracing type stubs and fix associated mypy errors. ([\#11603](https://github.com/matrix-org/synapse/issues/11603), [\#11622](https://github.com/matrix-org/synapse/issues/11622))
- Improve OpenTracing support for requests which use a `ResponseCache`. ([\#11607](https://github.com/matrix-org/synapse/issues/11607))
- Improve OpenTracing support for incoming HTTP requests. ([\#11618](https://github.com/matrix-org/synapse/issues/11618))
- A number of improvements to opentracing support. ([\#11619](https://github.com/matrix-org/synapse/issues/11619))
- Drop support for Python 3.6 and Ubuntu 18.04. ([\#11633](https://github.com/matrix-org/synapse/issues/11633))
- Refactor the way that the `outlier` flag is set on events received over federation. ([\#11634](https://github.com/matrix-org/synapse/issues/11634))
- Improve the error messages from `get_create_event_for_room`. ([\#11638](https://github.com/matrix-org/synapse/issues/11638))
- Remove redundant `get_current_events_token` method. ([\#11643](https://github.com/matrix-org/synapse/issues/11643))
- Convert `namedtuples` to `attrs`. ([\#11665](https://github.com/matrix-org/synapse/issues/11665), [\#11574](https://github.com/matrix-org/synapse/issues/11574))
- Update the `/capabilities` response to include whether support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) is available. ([\#11690](https://github.com/matrix-org/synapse/issues/11690))
- Send the `Accept` header in HTTP requests made using `SimpleHttpClient.get_json`. ([\#11677](https://github.com/matrix-org/synapse/issues/11677))
- Work around Mjolnir compatibility issue by adding an import for `glob_to_regex` in `synapse.util`, where it moved from. ([\#11696](https://github.com/matrix-org/synapse/issues/11696))
Synapse 1.49.2 (2021-12-21)
===========================

View File

@ -14,6 +14,7 @@ services:
# failure
restart: unless-stopped
# See the readme for a full documentation of the environment settings
# NOTE: You must edit homeserver.yaml to use postgres, it defaults to sqlite
environment:
- SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
volumes:

6
debian/changelog vendored
View File

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.50.0~rc1) stable; urgency=medium
* New synapse release 1.50.0~rc1.
-- Synapse Packaging team <packages@matrix.org> Wed, 05 Jan 2022 12:36:17 +0000
matrix-synapse-py3 (1.49.2) stable; urgency=medium
* New synapse release 1.49.2.

View File

@ -16,7 +16,7 @@ ARG distro=""
### Stage 0: build a dh-virtualenv
###
# This is only really needed on bionic and focal, since other distributions we
# This is only really needed on focal, since other distributions we
# care about have a recent version of dh-virtualenv by default. Unfortunately,
# it looks like focal is going to be with us for a while.
#
@ -36,9 +36,8 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11)
RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/refs/tags/1.2.2.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
# install its build deps. We do another apt-cache-update here, because we might
@ -86,12 +85,12 @@ RUN apt-get update -qq -o Acquire::Languages=none \
libpq-dev \
xmlsec1
COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
COPY --from=builder /dh-virtualenv_1.2.2-1_all.deb /
# install dhvirtualenv. Update the apt cache again first, in case we got a
# cached cache from docker the first time.
RUN apt-get update -qq -o Acquire::Languages=none \
&& apt-get install -yq /dh-virtualenv_1.2~dev-1_all.deb
&& apt-get install -yq /dh-virtualenv_1.2.2-1_all.deb
WORKDIR /synapse/source
ENTRYPOINT ["bash","/synapse/source/docker/build_debian.sh"]

View File

@ -68,6 +68,10 @@ The following environment variables are supported in `generate` mode:
directories. If unset, and no user is set via `docker run --user`, defaults
to `991`, `991`.
## Postgres
By default the config will use SQLite. See the [docs on using Postgres](https://github.com/matrix-org/synapse/blob/develop/docs/postgres.md) for more info on how to use Postgres. Until this section is improved [this issue](https://github.com/matrix-org/synapse/issues/8304) may provide useful information.
## Running synapse
Once you have a valid configuration file, you can start synapse as follows:

View File

@ -30,6 +30,7 @@
- [SSO Mapping Providers](sso_mapping_providers.md)
- [Password Auth Providers](password_auth_providers.md)
- [JSON Web Tokens](jwt.md)
- [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md)
- [Registration Captcha](CAPTCHA_SETUP.md)
- [Application Services](application_services.md)
- [Server Notices](server_notices.md)

View File

@ -480,6 +480,81 @@ The following fields are returned in the JSON response body:
- `joined_rooms` - An array of `room_id`.
- `total` - Number of rooms.
## Account Data
Gets information about account data for a specific `user_id`.
The API is:
```
GET /_synapse/admin/v1/users/<user_id>/accountdata
```
A response body like the following is returned:
```json
{
"account_data": {
"global": {
"m.secret_storage.key.LmIGHTg5W": {
"algorithm": "m.secret_storage.v1.aes-hmac-sha2",
"iv": "fwjNZatxg==",
"mac": "eWh9kNnLWZUNOgnc="
},
"im.vector.hide_profile": {
"hide_profile": true
},
"org.matrix.preview_urls": {
"disable": false
},
"im.vector.riot.breadcrumb_rooms": {
"rooms": [
"!LxcBDAsDUVAfJDEo:matrix.org",
"!MAhRxqasbItjOqxu:matrix.org"
]
},
"m.accepted_terms": {
"accepted": [
"https://example.org/somewhere/privacy-1.2-en.html",
"https://example.org/somewhere/terms-2.0-en.html"
]
},
"im.vector.setting.breadcrumbs": {
"recent_rooms": [
"!MAhRxqasbItqxuEt:matrix.org",
"!ZtSaPCawyWtxiImy:matrix.org"
]
}
},
"rooms": {
"!GUdfZSHUJibpiVqHYd:matrix.org": {
"m.fully_read": {
"event_id": "$156334540fYIhZ:matrix.org"
}
},
"!tOZwOOiqwCYQkLhV:matrix.org": {
"m.fully_read": {
"event_id": "$xjsIyp4_NaVl2yPvIZs_k1Jl8tsC_Sp23wjqXPno"
}
}
}
}
}
```
**Parameters**
The following parameters should be set in the URL:
- `user_id` - fully qualified: for example, `@user:server.com`.
**Response**
The following fields are returned in the JSON response body:
- `account_data` - A map containing the account data for the user
- `global` - A map containing the global account data for the user
- `rooms` - A map containing the account data per room for the user
## User media
### List media uploaded by a user

View File

@ -63,7 +63,7 @@ server {
server_name matrix.example.com;
location ~* ^(\/_matrix|\/_synapse\/client) {
location ~ ^(/_matrix|/_synapse/client) {
# note: do not add a path (even a single /) after the port in `proxy_pass`,
# otherwise nginx will canonicalise the URI and cause signature verification
# errors.

View File

@ -37,7 +37,7 @@
# Server admins can expand Synapse's functionality with external modules.
#
# See https://matrix-org.github.io/synapse/latest/modules.html for more
# See https://matrix-org.github.io/synapse/latest/modules/index.html for more
# documentation on how to configure or create custom modules for Synapse.
#
modules:
@ -1488,6 +1488,7 @@ room_prejoin_state:
# - m.room.encryption
# - m.room.name
# - m.room.create
# - m.room.topic
#
# Uncomment the following to disable these defaults (so that only the event
# types listed in 'additional_event_types' are shared). Defaults to 'false'.

View File

@ -164,7 +164,7 @@ xbps-install -S synapse
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- Packages: `pkg install py37-matrix-synapse`
- Packages: `pkg install py38-matrix-synapse`
#### OpenBSD

View File

@ -49,12 +49,12 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
* `def __init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
* `parse_config(config)`
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
@ -63,13 +63,13 @@ A custom mapping provider must specify the following methods:
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_remote_user_id(self, userinfo)`
* `def get_remote_user_id(self, userinfo)`
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- This method must return a string, which is the unique, immutable identifier
for the user. Commonly the `sub` claim of the response.
* `map_user_attributes(self, userinfo, token, failures)`
* `async def map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@ -91,7 +91,7 @@ A custom mapping provider must specify the following methods:
during a user's first login. Once a localpart has been associated with a
remote user ID (see `get_remote_user_id`) it cannot be updated.
- `displayname`: An optional string, the display name for the user.
* `get_extra_attributes(self, userinfo, token)`
* `async def get_extra_attributes(self, userinfo, token)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@ -125,15 +125,15 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config, module_api)`
* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
* `def parse_config(config)`
- **This method should have the `@staticmethod` decoration.**
- Arguments:
- `config` - A `dict` representing the parsed content of the
`saml_config.user_mapping_provider.config` homeserver config option.
@ -141,15 +141,15 @@ A custom mapping provider must specify the following methods:
any option values they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_saml_attributes(config)`
- This method should have the `@staticmethod` decoration.
* `def get_saml_attributes(config)`
- **This method should have the `@staticmethod` decoration.**
- Arguments:
- `config` - A object resulting from a call to `parse_config`.
- Returns a tuple of two sets. The first set equates to the SAML auth
response attributes that are required for the module to function, whereas
the second set consists of those attributes which can be used if available,
but are not necessary.
* `get_remote_user_id(self, saml_response, client_redirect_url)`
* `def get_remote_user_id(self, saml_response, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
@ -157,7 +157,7 @@ A custom mapping provider must specify the following methods:
redirected to.
- This method must return a string, which is the unique, immutable identifier
for the user. Commonly the `uid` claim of the response.
* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
* `def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.

View File

@ -15,8 +15,8 @@ The following sections describe how to install [coturn](<https://github.com/cotu
For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint with a public IP.
Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues
and to often not work.
Hosting TURN behind NAT requires port forwaring and for the NAT gateway to have a public IP.
However, even with appropriate configuration, NAT is known to cause issues and to often not work.
## `coturn` setup
@ -103,7 +103,23 @@ This will install and start a systemd service called `coturn`.
denied-peer-ip=192.168.0.0-192.168.255.255
denied-peer-ip=172.16.0.0-172.31.255.255
# recommended additional local peers to block, to mitigate external access to internal services.
# https://www.rtcsec.com/article/slack-webrtc-turn-compromise-and-bug-bounty/#how-to-fix-an-open-turn-relay-to-address-this-vulnerability
no-multicast-peers
denied-peer-ip=0.0.0.0-0.255.255.255
denied-peer-ip=100.64.0.0-100.127.255.255
denied-peer-ip=127.0.0.0-127.255.255.255
denied-peer-ip=169.254.0.0-169.254.255.255
denied-peer-ip=192.0.0.0-192.0.0.255
denied-peer-ip=192.0.2.0-192.0.2.255
denied-peer-ip=192.88.99.0-192.88.99.255
denied-peer-ip=198.18.0.0-198.19.255.255
denied-peer-ip=198.51.100.0-198.51.100.255
denied-peer-ip=203.0.113.0-203.0.113.255
denied-peer-ip=240.0.0.0-255.255.255.255
# special case the turn server itself so that client->TURN->TURN->client flows work
# this should be one of the turn server's listening IPs
allowed-peer-ip=10.0.0.1
# consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
@ -123,7 +139,7 @@ This will install and start a systemd service called `coturn`.
pkey=/path/to/privkey.pem
```
In this case, replace the `turn:` schemes in the `turn_uri` settings below
In this case, replace the `turn:` schemes in the `turn_uris` settings below
with `turns:`.
We recommend that you only try to set up TLS/DTLS once you have set up a
@ -134,21 +150,33 @@ This will install and start a systemd service called `coturn`.
traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
for the UDP relay.)
1. We do not recommend running a TURN server behind NAT, and are not aware of
anyone doing so successfully.
If you want to try it anyway, you will at least need to tell coturn its
external IP address:
1. If your TURN server is behind NAT, the NAT gateway must have an external,
publicly-reachable IP address. You must configure coturn to advertise that
address to connecting clients:
```
external-ip=192.88.99.1
external-ip=EXTERNAL_NAT_IPv4_ADDRESS
```
... and your NAT gateway must forward all of the relayed ports directly
(eg, port 56789 on the external IP must be always be forwarded to port
56789 on the internal IP).
You may optionally limit the TURN server to listen only on the local
address that is mapped by NAT to the external address:
If you get this working, let us know!
```
listening-ip=INTERNAL_TURNSERVER_IPv4_ADDRESS
```
If your NAT gateway is reachable over both IPv4 and IPv6, you may
configure coturn to advertise each available address:
```
external-ip=EXTERNAL_NAT_IPv4_ADDRESS
external-ip=EXTERNAL_NAT_IPv6_ADDRESS
```
When advertising an external IPv6 address, ensure that the firewall and
network settings of the system running your TURN server are configured to
accept IPv6 traffic, and that the TURN server is listening on the local
IPv6 address that is mapped by NAT to the external IPv6 address.
1. (Re)start the turn server:
@ -216,9 +244,6 @@ connecting". Unfortunately, troubleshooting this can be tricky.
Here are a few things to try:
* Check that your TURN server is not behind NAT. As above, we're not aware of
anyone who has successfully set this up.
* Check that you have opened your firewall to allow TCP and UDP traffic to the
TURN ports (normally 3478 and 5349).
@ -234,6 +259,18 @@ Here are a few things to try:
Try removing any AAAA records for your TURN server, so that it is only
reachable over IPv4.
* If your TURN server is behind NAT:
* double-check that your NAT gateway is correctly forwarding all TURN
ports (normally 3478 & 5349 for TCP & UDP TURN traffic, and 49152-65535 for the UDP
relay) to the NAT-internal address of your TURN server. If advertising
both IPv4 and IPv6 external addresses via the `external-ip` option, ensure
that the NAT is forwarding both IPv4 and IPv6 traffic to the IPv4 and IPv6
internal addresses of your TURN server. When in doubt, remove AAAA records
for your TURN server and specify only an IPv4 address as your `external-ip`.
* ensure that your TURN server uses the NAT gateway as its default route.
* Enable more verbose logging in coturn via the `verbose` setting:
```

View File

@ -85,6 +85,17 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
# Upgrading to v1.50.0
## Dropping support for old Python and Postgres versions
In line with our [deprecation policy](deprecation_policy.md),
we've dropped support for Python 3.6 and PostgreSQL 9.6, as they are no
longer supported upstream.
This release of Synapse requires Python 3.7+ and PostgreSQL 10+.
# Upgrading to v1.47.0
## Removal of old Room Admin API

View File

@ -0,0 +1,139 @@
# Refresh Tokens
Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible).
[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
## Background and motivation
Synapse users' sessions are identified by **access tokens**; access tokens are
issued to users on login. Each session gets a unique access token which identifies
it; the access token must be kept secret as it grants access to the user's account.
Traditionally, these access tokens were eternally valid (at least until the user
explicitly chose to log out).
In some cases, it may be desirable for these access tokens to expire so that the
potential damage caused by leaking an access token is reduced.
On the other hand, forcing a user to re-authenticate (log in again) often might
be too much of an inconvenience.
**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst
still getting most of the benefits of short access token lifetimes.
Refresh tokens are also a concept present in OAuth 2 — further reading is available
[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5).
When refresh tokens are in use, both an access token and a refresh token will be
issued to users on login. The access token will expire after a predetermined amount
of time, but otherwise works in the same way as before. When the access token is
close to expiring (or has expired), the user's client should present the homeserver
(Synapse) with the refresh token.
The homeserver will then generate a new access token and refresh token for the user
and return them. The old refresh token is invalidated and can not be used again*.
Finally, refresh tokens also make it possible for sessions to be logged out if they
are inactive for too long, before the session naturally ends; see the configuration
guide below.
*To prevent issues if clients lose connection half-way through refreshing a token,
the refresh token is only invalidated once the new access token has been used at
least once. For all intents and purposes, the above simplification is sufficient.
## Caveats
There are some caveats:
* If a third party gets both your access token and refresh token, they will be able to
continue to enjoy access to your session.
* This is still an improvement because you (the user) will notice when *your*
session expires and you're not able to use your refresh token.
That would be a giveaway that someone else has compromised your session.
You would be able to log in again and terminate that session.
Previously (with long-lived access tokens), a third party that has your access
token could go undetected for a very long time.
* Clients need to implement support for refresh tokens in order for them to be a
useful mechanism.
* It is up to homeserver administrators if they want to issue long-lived access
tokens to clients not implementing refresh tokens.
* For compatibility, it is likely that they should, at least until client support
is widespread.
* Users with clients that support refresh tokens will still benefit from the
added security; it's not possible to downgrade a session to using long-lived
access tokens so this effectively gives users the choice.
* In a closed environment where all users use known clients, this may not be
an issue as the homeserver administrator can know if the clients have refresh
token support. In that case, the non-refreshable access token lifetime
may be set to a short duration so that a similar level of security is provided.
## Configuration Guide
The following configuration options, in the `registration` section, are related:
* `session_lifetime`: maximum length of a session, even if it's refreshed.
In other words, the client must log in again after this time period.
In most cases, this can be unset (infinite) or set to a long time (years or months).
* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
by clients supporting refresh tokens.
This should be short; a good value might be 5 minutes (`5m`).
* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
by clients which don't support refresh tokens.
Make this short if you want to effectively force use of refresh tokens.
Make this long if you don't want to inconvenience users of clients which don't
support refresh tokens (by forcing them to frequently re-authenticate using
login credentials).
* `refresh_token_lifetime`: lifetime of refresh tokens.
In other words, the client must refresh within this time period to maintain its session.
Unless you want to log inactive sessions out, it is often fine to use a long
value here or even leave it unset (infinite).
Beware that making it too short will inconvenience clients that do not connect
very often, including mobile clients and clients of infrequent users (by making
it more difficult for them to refresh in time, which may force them to need to
re-authenticate using login credentials).
**Note:** All four options above only apply when tokens are created (by logging in or refreshing).
Changes to these settings do not apply retroactively.
### Using refresh token expiry to log out inactive sessions
If you'd like to force sessions to be logged out upon inactivity, you can enable
refreshable access token expiry and refresh token expiry.
This works because a client must refresh at least once within a period of
`refresh_token_lifetime` in order to maintain valid credentials to access the
account.
(It's suggested that `refresh_token_lifetime` should be longer than
`refreshable_access_token_lifetime` and this section assumes that to be the case
for simplicity.)
Note: this will only affect sessions using refresh tokens. You may wish to
set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed
by clients that do not support refresh tokens.
#### Choosing values that guarantee permitting some inactivity
It may be desirable to permit some short periods of inactivity, for example to
accommodate brief outages in client connectivity.
The following model aims to provide guidance for choosing `refresh_token_lifetime`
and `refreshable_access_token_lifetime` to satisfy requirements of the form:
1. inactivity longer than `L` **MUST** cause the session to be logged out; and
2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
This model makes the weakest assumption that all active clients will refresh as
needed to maintain an active access token, but no sooner.
*In reality, clients may refresh more often than this model assumes, but the
above requirements will still hold.*
To satisfy the above model,
* `refresh_token_lifetime` should be set to `L`; and
* `refreshable_access_token_lifetime` should be set to `L - S`.

View File

@ -25,14 +25,9 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/e2e_room_keys.py
|synapse/storage/databases/main/end_to_end_keys.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@ -40,12 +35,9 @@ exclude = (?x)
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/room.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/
@ -107,7 +99,6 @@ exclude = (?x)
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
|tests/storage/test_account_data.py
|tests/storage/test_background_update.py
|tests/storage/test_base.py
|tests/storage/test_client_ips.py
@ -145,6 +136,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True
[mypy-synapse.appservice.*]
disallow_untyped_defs = True
[mypy-synapse.config._base]
disallow_untyped_defs = True
@ -163,6 +157,12 @@ disallow_untyped_defs = False
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.http.server]
disallow_untyped_defs = True
[mypy-synapse.logging.context]
disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
@ -181,24 +181,48 @@ disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.account_data]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.client_ips]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.e2e_room_keys]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.end_to_end_keys]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_bg_updates]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.stats]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.transactions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True
@ -223,6 +247,9 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True
[mypy-tests.rest.admin.*]
disallow_untyped_defs = True
[mypy-tests.rest.client.test_directory]
disallow_untyped_defs = True
@ -286,9 +313,6 @@ ignore_missing_imports = True
[mypy-netaddr]
ignore_missing_imports = True
[mypy-opentracing]
ignore_missing_imports = True
[mypy-parameterized.*]
ignore_missing_imports = True

View File

@ -35,7 +35,7 @@
showcontent = true
[tool.black]
target-version = ['py36']
target-version = ['py37', 'py38', 'py39', 'py310']
exclude = '''
(

View File

@ -24,7 +24,6 @@ DISTS = (
"debian:bullseye",
"debian:bookworm",
"debian:sid",
"ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
"ubuntu:hirsute", # 21.04 (EOL 2022-01-05)
"ubuntu:impish", # 21.10 (EOL 2022-07)

View File

@ -42,8 +42,8 @@ echo "--------------------------"
echo
matched=0
for f in $(git diff --name-only FETCH_HEAD... -- changelog.d); do
# check that any modified newsfiles on this branch end with a full stop.
for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do
# check that any added newsfiles on this branch end with a full stop.
lastchar=$(tr -d '\n' < "$f" | tail -c 1)
if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then
echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2

View File

@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.7.0",
"black==21.6b0",
"black==21.12b0",
"flake8-comprehensions",
"flake8-bugbear==21.3.2",
"flake8",
@ -107,6 +107,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
"mypy-zope==0.3.2",
"types-bleach>=4.1.0",
"types-jsonschema>=3.2.0",
"types-opentracing>=2.4.2",
"types-Pillow>=8.3.4",
"types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10",
@ -119,9 +120,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
# Tests assume that all optional dependencies are installed.
#
# parameterized_class decorator was introduced in parameterized 0.7.0
#
# We use `mock` library as that backports `AsyncMock` to Python 3.6
CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
CONDITIONAL_REQUIREMENTS["dev"] = (
CONDITIONAL_REQUIREMENTS["lint"]
@ -163,7 +162,6 @@ setup(
"Topic :: Communications :: Chat",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",

View File

@ -17,11 +17,12 @@
from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol
from twisted.internet.defer import Deferred
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
def ping(self) -> "Deferred[None]": ...
def set(
self,
key: str,
value: Any,
@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
pexpire: Optional[int] = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> None: ...
async def get(self, key: str) -> Any: ...
) -> "Deferred[None]": ...
def get(self, key: str) -> "Deferred[Any]": ...
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...

View File

@ -47,7 +47,7 @@ try:
except ImportError:
pass
__version__ = "1.49.2"
__version__ = "1.50.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View File

@ -32,7 +32,7 @@ from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
@ -149,13 +149,53 @@ class Auth:
is invalid.
AuthError if access is denied for the user in the access token
"""
parent_span = active_span()
with start_active_span("get_user_by_req"):
requester = await self._wrapped_get_user_by_req(
request, allow_guest, rights, allow_expired
)
if parent_span:
if requester.authenticated_entity in self._force_tracing_for_users:
# request tracing is enabled for this user, so we need to force it
# tracing on for the parent span (which will be the servlet span).
#
# It's too late for the get_user_by_req span to inherit the setting,
# so we also force it on for that.
force_tracing()
force_tracing(parent_span)
parent_span.set_tag(
"authenticated_entity", requester.authenticated_entity
)
parent_span.set_tag("user_id", requester.user.to_string())
if requester.device_id is not None:
parent_span.set_tag("device_id", requester.device_id)
if requester.app_service is not None:
parent_span.set_tag("appservice_id", requester.app_service.id)
return requester
async def _wrapped_get_user_by_req(
self,
request: SynapseRequest,
allow_guest: bool,
rights: str,
allow_expired: bool,
) -> Requester:
"""Helper for get_user_by_req
Once get_user_by_req has set up the opentracing span, this does the actual work.
"""
try:
ip_addr = request.getClientIP()
user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
user_id, app_service = await self._get_appservice_user_id(request)
(
user_id,
device_id,
app_service,
) = await self._get_appservice_user_id_and_device_id(request)
if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
@ -163,18 +203,16 @@ class Auth:
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
device_id="dummy-device", # stubbed
device_id="dummy-device"
if device_id is None
else device_id, # stubbed
)
requester = create_requester(user_id, app_service=app_service)
requester = create_requester(
user_id, app_service=app_service, device_id=device_id
)
request.requester = user_id
if user_id in self._force_tracing_for_users:
opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
return requester
user_info = await self.get_user_by_access_token(
@ -232,13 +270,6 @@ class Auth:
)
request.requester = requester
if user_info.token_owner in self._force_tracing_for_users:
opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
return requester
except KeyError:
raise MissingClientTokenError()
@ -275,33 +306,81 @@ class Auth:
403, "Application service has not registered this user (%s)" % user_id
)
async def _get_appservice_user_id(
async def _get_appservice_user_id_and_device_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
"""
Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request
- what user the application service should be treated as controlling
(the user_id URI parameter allows an application service to masquerade
any applicable user in its namespace)
- what device the application service should be treated as controlling
(the device_id[^1] URI parameter allows an application service to masquerade
as any device that exists for the relevant user)
[^1] Unstable and provided by MSC3202.
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
Returns:
3-tuple of
(user ID?, device ID?, application service?)
Postconditions:
- If an application service is returned, so is a user ID
- A user ID is never returned without an application service
- A device ID is never returned without a user ID or an application service
- The returned application service, if present, is permitted to control the
returned user ID.
- The returned device ID, if present, has been checked to be a valid device ID
for the returned user ID.
"""
DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id"
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
if app_service is None:
return None, None
return None, None, None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientIP())
if ip_address not in app_service.ip_range_whitelist:
return None, None
return None, None, None
# This will always be set by the time Twisted calls us.
assert request.args is not None
if b"user_id" not in request.args:
return app_service.sender, app_service
if b"user_id" in request.args:
effective_user_id = request.args[b"user_id"][0].decode("utf8")
await self.validate_appservice_can_control_user_id(
app_service, effective_user_id
)
else:
effective_user_id = app_service.sender
user_id = request.args[b"user_id"][0].decode("utf8")
await self.validate_appservice_can_control_user_id(app_service, user_id)
effective_device_id: Optional[str] = None
if app_service.sender == user_id:
return app_service.sender, app_service
if (
self.hs.config.experimental.msc3202_device_masquerading_enabled
and DEVICE_ID_ARG_NAME in request.args
):
effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
# We only just set this so it can't be None!
assert effective_device_id is not None
device_opt = await self.store.get_device(
effective_user_id, effective_device_id
)
if device_opt is None:
# For now, use 400 M_EXCLUSIVE if the device doesn't exist.
# This is an open thread of discussion on MSC3202 as of 2021-12-09.
raise AuthError(
400,
f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})",
Codes.EXCLUSIVE,
)
return user_id, app_service
return effective_user_id, effective_device_id, app_service
async def get_user_by_access_token(
self,

View File

@ -253,5 +253,9 @@ class GuestAccess:
FORBIDDEN: Final = "forbidden"
class ReceiptTypes:
READ: Final = "m.read"
class ReadReceiptEventFields:
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"

View File

@ -351,8 +351,7 @@ class Filter:
True if the event matches the filter.
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
# namedtuple type.
# except for presence which actually gets passed around as its own type.
if isinstance(event, UserPresenceState):
user_id = event.user_id
field_matchers = {

View File

@ -27,6 +27,7 @@ import synapse
import synapse.config.logger
from synapse import events
from synapse.api.urls import (
CLIENT_API_PREFIX,
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
MEDIA_R0_PREFIX,
@ -192,13 +193,7 @@ class SynapseHomeServer(HomeServer):
resources.update(
{
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/v1": client_resource,
"/_matrix/client/v3": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
CLIENT_API_PREFIX: client_resource,
"/.well-known": well_known_resource(self),
"/_synapse/admin": AdminRestResource(self),
**build_synapse_client_resource_tree(self),

View File

@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
import attr
from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
UP = "up"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Namespace:
exclusive: bool
group_id: Optional[str]
regex: Pattern[str]
class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@ -50,17 +61,17 @@ class ApplicationService:
def __init__(
self,
token,
hostname,
id,
sender,
url=None,
namespaces=None,
hs_token=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
supports_ephemeral=False,
token: str,
hostname: str,
id: str,
sender: str,
url: Optional[str] = None,
namespaces: Optional[JsonDict] = None,
hs_token: Optional[str] = None,
protocols: Optional[Iterable[str]] = None,
rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None,
supports_ephemeral: bool = False,
):
self.token = token
self.url = (
@ -85,27 +96,33 @@ class ApplicationService:
self.rate_limited = rate_limited
def _check_namespaces(self, namespaces):
def _check_namespaces(
self, namespaces: Optional[JsonDict]
) -> Dict[str, List[Namespace]]:
# Sanity check that it is of the form:
# {
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
if namespaces is None:
namespaces = {}
result: Dict[str, List[Namespace]] = {}
for ns in ApplicationService.NS_LIST:
result[ns] = []
if ns not in namespaces:
namespaces[ns] = []
continue
if type(namespaces[ns]) != list:
if not isinstance(namespaces[ns], list):
raise ValueError("Bad namespace value for '%s'" % ns)
for regex_obj in namespaces[ns]:
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
exclusive = regex_obj.get("exclusive")
if not isinstance(exclusive, bool):
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
@ -126,22 +143,26 @@ class ApplicationService:
)
regex = regex_obj.get("regex")
if isinstance(regex, str):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
if not isinstance(regex, str):
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string):
return regex_obj
# Pre-compile regex.
result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
return result
def _matches_regex(
self, namespace_key: str, test_string: str
) -> Optional[Namespace]:
for namespace in self.namespaces[namespace_key]:
if namespace.regex.match(test_string):
return namespace
return None
def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
namespace = self._matches_regex(namespace_key, test_string)
if namespace:
return namespace.exclusive
return False
async def _matches_user(
@ -260,15 +281,15 @@ class ApplicationService:
def is_interested_in_user(self, user_id: str) -> bool:
return (
bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
or user_id == self.sender
)
def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
def is_exclusive_user(self, user_id: str) -> bool:
return (
@ -285,14 +306,14 @@ class ApplicationService:
def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self):
def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
return [
regex_obj["regex"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if regex_obj["exclusive"]
namespace.regex
for namespace in self.namespaces[ApplicationService.NS_USERS]
if namespace.exclusive
]
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@ -305,15 +326,15 @@ class ApplicationService:
An iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
namespace.group_id
for namespace in self.namespaces[ApplicationService.NS_USERS]
if namespace.group_id and namespace.regex.match(user_id)
)
def is_rate_limited(self) -> bool:
return self.rate_limited
def __str__(self):
def __str__(self) -> str:
# copy dictionary and redact token fields so they don't get logged
dict_copy = self.__dict__.copy()
dict_copy["token"] = "<redacted>"

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, List, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_metadata(info):
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
return True
def _is_valid_3pe_result(r, field):
def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
if not isinstance(r, dict):
return False
@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
async def query_user(self, service, user_id):
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None:
return False
# This is required by the configuration.
assert service.hs_token is not None
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False
async def query_alias(self, service, alias):
async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
if service.url is None:
return False
# This is required by the configuration.
assert service.hs_token is not None
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False
async def query_3pe(self, service, kind, protocol, fields):
async def query_3pe(
self,
service: "ApplicationService",
kind: str,
protocol: str,
fields: Dict[bytes, List[bytes]],
) -> List[JsonDict]:
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient):
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
):
) -> bool:
if service.url is None:
return True
events = self._serialize(service, events)
# This is required by the configuration.
assert service.hs_token is not None
serialized_events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient):
# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
body = {
"events": serialized_events,
"de.sorunome.msc2409.ephemeral": ephemeral,
}
else:
body = {"events": events}
body = {"events": serialized_events}
try:
await self.put_json(
@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
[event.get("event_id") for event in events],
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
sent_events_counter.labels(service.id).inc(len(serialized_events))
return True
except CodeMessageException as e:
logger.warning(
@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
def _serialize(self, service, events):
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
time_now = self.clock.time_msec()
return [
serialize_event(

View File

@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import List, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice.api import ApplicationServiceApi
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
from synapse.util import Clock
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
case is a simple array.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.as_api = hs.get_application_service_api()
@ -80,7 +86,7 @@ class ApplicationServiceScheduler:
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
async def start(self):
async def start(self) -> None:
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
@ -91,12 +97,14 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
def submit_event_for_as(
self, service: ApplicationService, event: EventBase
) -> None:
self.queuer.enqueue_event(service, event)
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[JsonDict]
):
) -> None:
self.queuer.enqueue_ephemeral(service, events)
@ -108,16 +116,18 @@ class _ServiceQueuer:
appservice at a given time.
"""
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
def _start_background_request(self, service):
def _start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
@ -126,15 +136,17 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
def enqueue_event(self, service: ApplicationService, event: EventBase):
def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
def enqueue_ephemeral(
self, service: ApplicationService, events: List[JsonDict]
) -> None:
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
async def _send_request(self, service: ApplicationService):
async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@ -168,20 +180,15 @@ class _TransactionController:
if a transaction fails.
(Note we have only have one of these in the homeserver.)
Args:
clock (synapse.util.Clock):
store (synapse.storage.DataStore):
as_api (synapse.appservice.api.ApplicationServiceApi):
"""
def __init__(self, clock, store, as_api):
def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
self.clock = clock
self.store = store
self.as_api = as_api
# map from service id to recoverer instance
self.recoverers = {}
self.recoverers: Dict[str, "_Recoverer"] = {}
# for UTs
self.RECOVERER_CLASS = _Recoverer
@ -191,7 +198,7 @@ class _TransactionController:
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
):
) -> None:
try:
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral or []
@ -207,7 +214,7 @@ class _TransactionController:
logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service)
async def on_recovered(self, recoverer):
async def on_recovered(self, recoverer: "_Recoverer") -> None:
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
@ -217,18 +224,18 @@ class _TransactionController:
recoverer.service, ApplicationServiceState.UP
)
async def _on_txn_fail(self, service):
async def _on_txn_fail(self, service: ApplicationService) -> None:
try:
await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service)
except Exception:
logger.exception("Error starting AS recoverer")
def start_recoverer(self, service):
def start_recoverer(self, service: ApplicationService) -> None:
"""Start a Recoverer for the given service
Args:
service (synapse.appservice.ApplicationService):
service:
"""
logger.info("Starting recoverer for AS ID %s", service.id)
assert service.id not in self.recoverers
@ -257,7 +264,14 @@ class _Recoverer:
callback (callable[_Recoverer]): called once the service recovers.
"""
def __init__(self, clock, store, as_api, service, callback):
def __init__(
self,
clock: Clock,
store: DataStore,
as_api: ApplicationServiceApi,
service: ApplicationService,
callback: Callable[["_Recoverer"], Awaitable[None]],
):
self.clock = clock
self.store = store
self.as_api = as_api
@ -265,8 +279,8 @@ class _Recoverer:
self.callback = callback
self.backoff_counter = 1
def recover(self):
def _retry():
def recover(self) -> None:
def _retry() -> None:
run_as_background_process(
"as-recoverer-%s" % (self.service.id,), self.retry
)
@ -275,13 +289,13 @@ class _Recoverer:
logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
self.clock.call_later(delay, _retry)
def _backoff(self):
def _backoff(self) -> None:
# cap the backoff to be around 8.5min => (2^9) = 512 secs
if self.backoff_counter < 9:
self.backoff_counter += 1
self.recover()
async def retry(self):
async def retry(self) -> None:
logger.info("Starting retries on %s", self.service.id)
try:
while True:

View File

@ -107,6 +107,8 @@ _DEFAULT_PREJOIN_STATE_TYPES = [
EventTypes.Name,
# Per MSC1772.
EventTypes.Create,
# Per MSC3173.
EventTypes.Topic,
]

View File

@ -147,8 +147,7 @@ def _load_appservice(
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
if not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):

View File

@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
# MSC2716 (backfill existing history)
# MSC2716 (importing historical messages)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
# MSC2285 (hidden read receipts)
@ -49,3 +49,8 @@ class ExperimentalConfig(Config):
# MSC3030 (Jump to date API endpoint)
self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
# The portion of MSC3202 which is related to device masquerading.
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)

View File

@ -16,12 +16,14 @@
import hashlib
import logging
import os
from typing import Any, Dict
from typing import Any, Dict, Iterator, List, Optional
import attr
import jsonschema
from signedjson.key import (
NACL_ED25519,
SigningKey,
VerifyKey,
decode_signing_key_base64,
decode_verify_key_bytes,
generate_signing_key,
@ -31,6 +33,7 @@ from signedjson.key import (
)
from unpaddedbase64 import decode_base64
from synapse.types import JsonDict
from synapse.util.stringutils import random_string, random_string_with_symbols
from ._base import Config, ConfigError
@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set
logger = logging.getLogger(__name__)
@attr.s
@attr.s(slots=True, auto_attribs=True)
class TrustedKeyServer:
# string: name of the server.
server_name = attr.ib()
# name of the server.
server_name: str
# dict[str,VerifyKey]|None: map from key id to key object, or None to disable
# signature verification.
verify_keys = attr.ib(default=None)
# map from key id to key object, or None to disable signature verification.
verify_keys: Optional[Dict[str, VerifyKey]] = None
class KeyConfig(Config):
@ -279,15 +281,15 @@ class KeyConfig(Config):
% locals()
)
def read_signing_keys(self, signing_key_path, name):
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
"""Read the signing keys in the given path.
Args:
signing_key_path (str)
name (str): Associated config key name
signing_key_path
name: Associated config key name
Returns:
list[SigningKey]
The signing keys read from the given path.
"""
signing_keys = self.read_file(signing_key_path, name)
@ -296,7 +298,9 @@ class KeyConfig(Config):
except Exception as e:
raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict]
) -> Dict[str, VerifyKey]:
if old_signing_keys is None:
return {}
keys = {}
@ -340,7 +344,7 @@ class KeyConfig(Config):
write_signing_keys(signing_key_file, (key,))
def _perspectives_to_key_servers(config):
def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]:
"""Convert old-style 'perspectives' configs into new-style 'trusted_key_servers'
Returns an iterable of entries to add to trusted_key_servers.
@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = {
}
def _parse_key_servers(key_servers, federation_verify_certificates):
def _parse_key_servers(
key_servers: List[Any], federation_verify_certificates: bool
) -> Iterator[TrustedKeyServer]:
try:
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
except jsonschema.ValidationError as e:
@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
yield result
def _assert_keyserver_has_verify_keys(trusted_key_server):
def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None:
if not trusted_key_server.verify_keys:
raise ConfigError(INSECURE_NOTARY_ERROR)

View File

@ -22,10 +22,12 @@ from ._base import Config, ConfigError
@attr.s
class MetricsFlags:
known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
known_servers: bool = attr.ib(
default=False, validator=attr.validators.instance_of(bool)
)
@classmethod
def all_off(cls):
def all_off(cls) -> "MetricsFlags":
"""
Instantiate the flags with all options set to off.
"""

View File

@ -37,7 +37,7 @@ class ModulesConfig(Config):
# Server admins can expand Synapse's functionality with external modules.
#
# See https://matrix-org.github.io/synapse/latest/modules.html for more
# See https://matrix-org.github.io/synapse/latest/modules/index.html for more
# documentation on how to configure or create custom modules for Synapse.
#
modules:

View File

@ -14,10 +14,11 @@
import logging
import os
from collections import namedtuple
from typing import Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
import attr
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\
HTTP_PROXY_SET_WARNING = """\
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig",
(
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
),
)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThumbnailRequirement:
width: int
height: int
method: str
media_type: str
@attr.s(frozen=True, slots=True, auto_attribs=True)
class MediaStorageProviderConfig:
store_local: bool # Whether to store newly uploaded local files
store_remote: bool # Whether to store newly downloaded remote files
store_synchronous: bool # Whether to wait for successful storage for local uploads
def parse_thumbnail_requirements(
@ -66,11 +69,10 @@ def parse_thumbnail_requirements(
method, and thumbnail media type to precalculate
Args:
thumbnail_sizes(list): List of dicts with "width", "height", and
"method" keys
thumbnail_sizes: List of dicts with "width", "height", and "method" keys
Returns:
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
Dictionary mapping from media type string to list of ThumbnailRequirement.
"""
requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:

View File

@ -15,8 +15,9 @@
from typing import List
from matrix_common.regex import glob_to_regex
from synapse.types import JsonDict
from synapse.util import glob_to_regex
from ._base import Config, ConfigError

View File

@ -1257,7 +1257,7 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.",
)
def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
"""Reads the three durations for the GC min interval option, returning seconds."""
if durations is None:
return None

View File

@ -16,11 +16,12 @@ import logging
import os
from typing import List, Optional, Pattern
from matrix_common.regex import glob_to_regex
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
from synapse.util import glob_to_regex
logger = logging.getLogger(__name__)
@ -132,7 +133,7 @@ class TlsConfig(Config):
self.tls_certificate: Optional[crypto.X509] = None
self.tls_private_key: Optional[crypto.PKey] = None
def read_certificate_from_disk(self):
def read_certificate_from_disk(self) -> None:
"""
Read the certificates and private key from disk.
"""

View File

@ -395,7 +395,7 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase],
time_now: int,
*,
bundle_aggregations: bool = True,
bundle_aggregations: bool = False,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
@ -454,23 +454,26 @@ class EventClientSerializer:
return
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include.
aggregations = {}
annotations = await self.store.get_aggregation_groups_for_event(event_id)
annotations = await self.store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id, room_id)
if edit:
# If there is an edit replace the content, preserving existing
@ -503,7 +506,7 @@ class EventClientSerializer:
(
thread_count,
latest_thread_event,
) = await self.store.get_thread_summary(event_id)
) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
@ -104,10 +103,6 @@ class FederationBase:
return pdu
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
async def _check_sigs_on_pdu(
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
) -> None:
@ -220,15 +215,12 @@ def _is_invite_via_3pid(event: EventBase) -> bool:
)
def event_from_pdu_json(
pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
) -> EventBase:
def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase:
"""Construct an EventBase from an event json received over federation
Args:
pdu_json: pdu as received over federation
room_version: The version of the room this event belongs to
outlier: True to mark this event as an outlier
Raises:
SynapseError: if the pdu is missing required fields or is otherwise
@ -252,6 +244,4 @@ def event_from_pdu_json(
validate_canonicaljson(pdu_json)
event = make_event_from_dict(pdu_json, room_version)
event.internal_metadata.outlier = outlier
return event

View File

@ -265,14 +265,11 @@ class FederationClient(FederationBase):
room_version = await self.store.get_room_version(room_id)
pdus = [
event_from_pdu_json(p, room_version, outlier=False)
for p in transaction_data_pdus
]
pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus]
# Check signatures and hash of pdus, removing any from the list that fail checks
pdus[:] = await self._check_sigs_and_hash_and_fetch(
dest, pdus, outlier=True, room_version=room_version
dest, pdus, room_version=room_version
)
return pdus
@ -282,7 +279,6 @@ class FederationClient(FederationBase):
destination: str,
event_id: str,
room_version: RoomVersion,
outlier: bool = False,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home
@ -292,9 +288,6 @@ class FederationClient(FederationBase):
destination: Which homeserver to query
event_id: event to fetch
room_version: version of the room
outlier: Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@ -316,8 +309,7 @@ class FederationClient(FederationBase):
)
pdu_list: List[EventBase] = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
event_from_pdu_json(p, room_version) for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
@ -334,7 +326,6 @@ class FederationClient(FederationBase):
destinations: Iterable[str],
event_id: str,
room_version: RoomVersion,
outlier: bool = False,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home
@ -347,9 +338,6 @@ class FederationClient(FederationBase):
destinations: Which homeservers to query
event_id: event to fetch
room_version: version of the room
outlier: Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@ -377,7 +365,6 @@ class FederationClient(FederationBase):
destination=destination,
event_id=event_id,
room_version=room_version,
outlier=outlier,
timeout=timeout,
)
@ -435,7 +422,6 @@ class FederationClient(FederationBase):
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
outlier: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
@ -451,7 +437,6 @@ class FederationClient(FederationBase):
origin
pdu
room_version
outlier: Whether the events are outliers or not
Returns:
A list of PDUs that have valid signatures and hashes.
@ -466,7 +451,6 @@ class FederationClient(FederationBase):
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=origin,
outlier=outlier,
room_version=room_version,
)
@ -482,7 +466,6 @@ class FederationClient(FederationBase):
pdu: EventBase,
origin: str,
room_version: RoomVersion,
outlier: bool = False,
) -> Optional[EventBase]:
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
its signature check then we check if we have it in the database and if
@ -494,9 +477,6 @@ class FederationClient(FederationBase):
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
The PDU (possibly redacted) if it has valid signatures and hashes.
@ -521,7 +501,6 @@ class FederationClient(FederationBase):
destinations=[pdu_origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
except SynapseError:
@ -541,13 +520,10 @@ class FederationClient(FederationBase):
room_version = await self.store.get_room_version(room_id)
auth_chain = [
event_from_pdu_json(p, room_version, outlier=True)
for p in res["auth_chain"]
]
auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]]
signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version
destination, auth_chain, room_version=room_version
)
return signed_auth
@ -816,7 +792,6 @@ class FederationClient(FederationBase):
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=event,
origin=destination,
outlier=True,
room_version=room_version,
)
@ -864,7 +839,6 @@ class FederationClient(FederationBase):
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=destination,
outlier=True,
room_version=room_version,
)
@ -1235,7 +1209,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version
destination, events, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:

View File

@ -28,9 +28,9 @@ from typing import (
Union,
)
from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@ -66,8 +66,8 @@ from synapse.replication.http.federation import (
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
@ -360,13 +360,13 @@ class FederationServer(FederationBase):
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
pdu_results, _ = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(
self._handle_pdus_in_txn, origin, transaction, request_time
),
run_in_background(self._handle_edus_in_txn, origin, transaction),
],
),
consumeErrors=True,
).addErrback(unwrapFirstError)
)

View File

@ -30,7 +30,6 @@ Events are replicated via a separate events stream.
"""
import logging
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Dict,
@ -43,6 +42,7 @@ from typing import (
Type,
)
import attr
from sortedcontainers import SortedDict
from synapse.api.presence import UserPresenceState
@ -382,13 +382,11 @@ class BaseFederationRow:
raise NotImplementedError()
class PresenceDestinationsRow(
BaseFederationRow,
namedtuple(
"PresenceDestinationsRow",
("state", "destinations"), # UserPresenceState # list[str]
),
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceDestinationsRow(BaseFederationRow):
state: UserPresenceState
destinations: List[str]
TypeId = "pd"
@staticmethod
@ -404,17 +402,15 @@ class PresenceDestinationsRow(
buff.presence_destinations.append((self.state, self.destinations))
class KeyedEduRow(
BaseFederationRow,
namedtuple(
"KeyedEduRow",
("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
),
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class KeyedEduRow(BaseFederationRow):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
key: Tuple[str, ...] # the edu key passed to send_edu
edu: Edu
TypeId = "k"
@staticmethod
@ -428,9 +424,12 @@ class KeyedEduRow(
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EduRow(BaseFederationRow):
"""Streams EDUs that don't have keys. See KeyedEduRow"""
edu: Edu
TypeId = "e"
@staticmethod
@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
"ParsedFederationStreamData",
(
"presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ParsedFederationStreamData:
# list of tuples of UserPresenceState and destinations
presence_destinations: List[Tuple[UserPresenceState, List[str]]]
# dict of destination -> { key -> Edu }
keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]]
# dict of destination -> [Edu]
edus: Dict[str, List[Edu]]
def process_rows_for_federation(

View File

@ -22,13 +22,11 @@ from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.http.server import HttpServer, ServletCallback
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
SynapseTags,
start_active_span,
start_active_span_from_request,
tags,
set_tag,
span_context_from_request,
start_active_span_follows_from,
whitelisted_homeserver,
)
from synapse.server import HomeServer
@ -279,30 +277,19 @@ class BaseFederationServlet:
logger.warning("authenticate_request failed: %s", e)
raise
request_tags = {
SynapseTags.REQUEST_ID: request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
"authenticated_entity": origin,
"servlet_name": request.request_metrics.name,
}
# update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin)
# Only accept the span context if the origin is authenticated
# and whitelisted
# if the origin is authenticated and whitelisted, link to its span context
context = None
if origin and whitelisted_homeserver(origin):
scope = start_active_span_from_request(
request, "incoming-federation-request", tags=request_tags
)
else:
scope = start_active_span(
"incoming-federation-request", tags=request_tags
)
context = span_context_from_request(request)
scope = start_active_span_follows_from(
"incoming-federation-request", contexts=(context,) if context else ()
)
with scope:
opentracing.inject_response_headers(request.responseHeaders)
if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d

View File

@ -462,9 +462,9 @@ class ApplicationServicesHandler:
Args:
room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
RoomAliasMapping or None if no association can be found.
"""
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()

View File

@ -997,9 +997,7 @@ class AuthHandler:
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
try:
await self.store.get_device(user_id, device_id)
except StoreError:
if await self.store.get_device(user_id, device_id) is None:
await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")

View File

@ -106,10 +106,10 @@ class DeviceWorkerHandler:
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = await self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
device = await self.store.get_device(user_id, device_id)
if device is None:
raise errors.NotFoundError()
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
@ -602,6 +602,8 @@ class DeviceHandler(DeviceWorkerHandler):
access_token, device_id
)
old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None:
raise errors.NotFoundError()
await self.store.update_device(user_id, device_id, old_device["display_name"])
# can't call self.delete_device because that will clobber the
# access token so call the storage layer directly

View File

@ -280,13 +280,15 @@ class DirectoryHandler:
users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
servers_set = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = [self.server_name] + [s for s in servers if s != self.server_name]
if self.server_name in servers_set:
servers = [self.server_name] + [
s for s in servers_set if s != self.server_name
]
else:
servers = list(servers)
servers = list(servers_set)
return {"room_id": room_id, "servers": servers}

View File

@ -65,8 +65,12 @@ class E2eKeysHandler:
else:
# Only register this edu handler on master as it requires writing
# device updates to the db
#
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
federation_registry.register_edu_handler(
"m.signing_key_update",
self._edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
"org.matrix.signing_key_update",
self._edu_updater.incoming_signing_key_update,
@ -576,7 +580,9 @@ class E2eKeysHandler:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
fallback_keys = keys.get("fallback_keys") or keys.get(
"org.matrix.msc2732.fallback_keys"
)
if fallback_keys and isinstance(fallback_keys, dict):
log_kv(
{

View File

@ -14,7 +14,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Dict, Optional
from typing_extensions import Literal
from synapse.api.errors import (
Codes,
@ -24,6 +26,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
@ -58,7 +61,9 @@ class E2eRoomKeysHandler:
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> List[JsonDict]:
) -> Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@ -72,8 +77,8 @@ class E2eRoomKeysHandler:
Raises:
NotFoundError: if the backup version does not exist
Returns:
A list of dicts giving the session_data and message metadata for
these room keys.
A dict giving the session_data and message metadata for these room keys.
`{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
# we deliberately take the lock to get keys so that changing the version
@ -273,7 +278,7 @@ class E2eRoomKeysHandler:
@staticmethod
def _should_replace_room_key(
current_room_key: Optional[JsonDict], room_key: JsonDict
current_room_key: Optional[RoomKey], room_key: RoomKey
) -> bool:
"""
Determine whether to replace a given current_room_key (if any)

View File

@ -79,13 +79,14 @@ class EventStreamHandler:
# thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = await self.notifier.get_events_for(
stream_result = await self.notifier.get_events_for(
auth_user,
pagin_config,
timeout,
is_guest=is_guest,
explicit_room_id=room_id,
)
events = stream_result.events
time_now = self.clock.time_msec()
@ -122,14 +123,12 @@ class EventStreamHandler:
events,
time_now,
as_client_event=as_client_event,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
)
chunk = {
"chunk": chunks,
"start": await tokens[0].to_string(self.store),
"end": await tokens[1].to_string(self.store),
"start": await stream_result.start_token.to_string(self.store),
"end": await stream_result.end_token.to_string(self.store),
}
return chunk

View File

@ -360,31 +360,34 @@ class FederationHandler:
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = await make_deferred_yieldable(
states_list = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
)
# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
# A map from event_id to state map of event_ids.
state_ids: Dict[str, StateMap[str]] = dict(
zip(event_ids, [s.state for s in states_list])
)
state_map = await self.store.get_events(
[e_id for ids in states.values() for e_id in ids.values()],
[e_id for ids in state_ids.values() for e_id in ids.values()],
get_prev_content=False,
)
states = {
# A map from event_id to state map of events.
state_events: Dict[str, StateMap[EventBase]] = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
}
for key, state_dict in states.items()
for key, state_dict in state_ids.items()
}
for e_id in event_ids:
likely_extremeties_domains = get_domains_from_state(states[e_id])
likely_extremeties_domains = get_domains_from_state(state_events[e_id])
success = await try_backfill(
[

View File

@ -421,9 +421,6 @@ class FederationEventHandler:
Raises:
SynapseError if the response is in some way invalid.
"""
for e in itertools.chain(auth_events, state):
e.internal_metadata.outlier = True
event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
create_event = None
@ -666,7 +663,9 @@ class FederationEventHandler:
logger.info("Processing pulled event %s", event)
# these should not be outliers.
assert not event.internal_metadata.is_outlier()
assert (
not event.internal_metadata.is_outlier()
), "pulled event unexpectedly flagged as outlier"
event_id = event.event_id
@ -1192,7 +1191,6 @@ class FederationEventHandler:
[destination],
event_id,
room_version,
outlier=True,
)
if event is None:
logger.warning(
@ -1221,9 +1219,10 @@ class FederationEventHandler:
"""Persist a batch of outlier events fetched from remote servers.
We first sort the events to make sure that we process each event's auth_events
before the event itself, and then auth and persist them.
before the event itself.
Notifies about the events where appropriate.
We then mark the events as outliers, persist them to the database, and, where
appropriate (eg, an invite), awake the notifier.
Params:
room_id: the room that the events are meant to be in (though this has
@ -1274,7 +1273,8 @@ class FederationEventHandler:
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.
Notifies about the events where appropriate.
Marks the events as outliers, auths them, persists them to the database, and,
where appropriate (eg, an invite), awakes the notifier.
Params:
origin: where the events came from
@ -1312,6 +1312,9 @@ class FederationEventHandler:
return None
auth.append(ae)
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
context = EventContext.for_outlier()
try:
validate_event_for_room_version(room_version_obj, event)
@ -1838,7 +1841,7 @@ class FederationEventHandler:
The stream ID after which all events have been persisted.
"""
if not event_and_contexts:
return self._store.get_current_events_token()
return self._store.get_room_max_stream_ordering()
instance = self._config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:

View File

@ -13,21 +13,27 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.internet import defer
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
)
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.async_helpers import concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
@ -167,8 +173,6 @@ class InitialSyncHandler:
d["invite"] = await self._event_serializer.serialize_event(
invite_event,
time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
@ -190,14 +194,13 @@ class InitialSyncHandler:
)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
(messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
@ -205,7 +208,7 @@ class InitialSyncHandler:
end_token=room_end_token,
),
deferred_room_state,
]
)
)
).addErrback(unwrapFirstError)
@ -222,8 +225,6 @@ class InitialSyncHandler:
await self._event_serializer.serialize_events(
messages,
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
),
@ -234,8 +235,6 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
@ -377,9 +376,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
await self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
@ -387,7 +384,7 @@ class InitialSyncHandler:
"state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
room_state.values(), time_now, bundle_aggregations=False
room_state.values(), time_now
)
),
"presence": [],
@ -408,7 +405,7 @@ class InitialSyncHandler:
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
current_state.values(), time_now, bundle_aggregations=False
current_state.values(), time_now
)
now_token = self.hs.get_event_sources().get_current_token()
@ -454,8 +451,8 @@ class InitialSyncHandler:
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(get_presence),
run_in_background(get_receipts),
run_in_background(
@ -464,7 +461,7 @@ class InitialSyncHandler:
limit=limit,
end_token=now_token.room_key,
),
],
),
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@ -483,9 +480,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
await self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),

View File

@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@ -498,6 +497,7 @@ class EventCreationHandler:
require_consent: bool = True,
outlier: bool = False,
historical: bool = False,
allow_no_prev_events: bool = False,
depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""
@ -609,6 +609,7 @@ class EventCreationHandler:
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
depth=depth,
allow_no_prev_events=allow_no_prev_events,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@ -884,6 +885,7 @@ class EventCreationHandler:
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
allow_no_prev_events: bool = False,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@ -914,6 +916,7 @@ class EventCreationHandler:
full_state_ids_at_event = None
if auth_event_ids is not None:
# If auth events are provided, prev events must be also.
# prev_event_ids could be an empty array though.
assert prev_event_ids is not None
# Copy the full auth state before it stripped down
@ -945,14 +948,22 @@ class EventCreationHandler:
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
# we now ought to have some prev_events (unless it's a create event).
#
# do a quick sanity check here, rather than waiting until we've created the
# Do a quick sanity check here, rather than waiting until we've created the
# event and then try to auth it (which fails with a somewhat confusing "No
# create event in auth events")
assert (
builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events"
if allow_no_prev_events:
# We allow events with no `prev_events` but it better have some `auth_events`
assert (
builder.type == EventTypes.Create
# Allow an event to have empty list of prev_event_ids
# only if it has auth_event_ids.
or auth_event_ids
), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
else:
# we now ought to have some prev_events (unless it's a create event).
assert (
builder.type == EventTypes.Create or prev_event_ids
), "Attempting to create a non-m.room.create event with no prev_events"
event = await builder.build(
prev_event_ids=prev_event_ids,
@ -1158,9 +1169,9 @@ class EventCreationHandler:
# We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
result = await make_deferred_yieldable(
defer.gatherResults(
[
result, _ = await make_deferred_yieldable(
gather_results(
(
run_in_background(
self._persist_event,
requester=requester,
@ -1172,12 +1183,12 @@ class EventCreationHandler:
run_in_background(
self.cache_joined_hosts_for_event, event, context
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
],
),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
return result[0]
return result
async def _persist_event(
self,

View File

@ -542,7 +542,10 @@ class PaginationHandler:
chunk = {
"chunk": (
await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
events,
time_now,
bundle_aggregations=True,
as_client_event=as_client_event,
)
),
"start": await from_token.to_string(self.store),

View File

@ -729,7 +729,7 @@ class PresenceHandler(BasePresenceHandler):
# Presence is best effort and quickly heals itself, so lets just always
# stream from the current state when we restart.
self._event_pos = self.store.get_current_events_token()
self._event_pos = self.store.get_room_max_stream_ordering()
self._event_processing = False
async def _on_shutdown(self) -> None:

View File

@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@ -179,7 +179,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
for event_id in content.keys():
event_content = content.get(event_id, {})
m_read = event_content.get("m.read", {})
m_read = event_content.get(ReceiptTypes.READ, {})
# If m_read is missing copy over the original event_content as there is nothing to process here
if not m_read:
@ -207,7 +207,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
# Set new users unless empty
if len(new_users.keys()) > 0:
new_event["content"][event_id] = {"m.read": new_users}
new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
# Append new_event to visible_events unless empty
if len(new_event["content"].keys()) > 0:

View File

@ -172,7 +172,7 @@ class RoomCreationHandler:
user_id = requester.user.to_string()
# Check if this room is already being upgraded by another person
for key in self._upgrade_response_cache.pending_result_cache:
for key in self._upgrade_response_cache.keys():
if key[0] == old_room_id and key[1] != user_id:
# Two different people are trying to upgrade the same room.
# Send the second an error.

View File

@ -13,9 +13,9 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
import attr
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@ -474,16 +474,12 @@ class RoomListHandler:
)
class RoomListNextBatch(
namedtuple(
"RoomListNextBatch",
(
"last_joined_members", # The count to get rooms after/before
"last_room_id", # The room_id to get rooms after/before
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
),
)
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomListNextBatch:
last_joined_members: int # The count to get rooms after/before
last_room_id: str # The room_id to get rooms after/before
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
@ -502,12 +498,12 @@ class RoomListNextBatch(
def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
)
)
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)
return attr.evolve(self, **kwds)
def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:

View File

@ -658,7 +658,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
if prev_event_ids:
# An empty prev_events list is allowed as long as the auth_event_ids are present
if prev_event_ids is not None:
return await self._local_membership_update(
requester=requester,
target=target,
@ -1019,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room and old_room["is_public"]:
if old_room is not None and old_room["is_public"]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
@ -1030,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
# Add new the new room to those groups
await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
await self.store.add_room_to_group(
group_id, room_id, old_room is not None and old_room["is_public"]
)
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)

View File

@ -80,6 +80,17 @@ class StatsHandler:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos > room_max_stream_ordering:
# apparently, we've processed more events than exist in the database!
# this can happen if events are removed with history purge or similar.
logger.warning(
"Event stream ordering appears to have gone backwards (%i -> %i): "
"rewinding stats processor",
self.pos,
room_max_stream_ordering,
)
self.pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date

View File

@ -28,7 +28,7 @@ from typing import (
import attr
from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -36,6 +36,7 @@ from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@ -421,7 +422,7 @@ class SyncHandler:
span to track the sync. See `generate_sync_result` for the next part of your
indoctrination.
"""
with start_active_span("current_sync_for_user"):
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state
@ -1040,18 +1041,17 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> Dict[str, int]:
) -> NotifCounts:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read",
receipt_type=ReceiptTypes.READ,
)
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
async def generate_sync_result(
self,
@ -1584,7 +1584,8 @@ class SyncHandler:
)
logger.debug("Generated room entry for %s", room_entry.room_id)
await concurrently_execute(handle_room_entries, room_entries, 10)
with start_active_span("sync.generate_room_entries"):
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
@ -1661,20 +1662,20 @@ class SyncHandler:
) -> _RoomChanges:
"""Determine the changes in rooms to report to the user.
Ideally, we want to report all events whose stream ordering `s` lies in the
range `since_token < s <= now_token`, where the two tokens are read from the
sync_result_builder.
This function is a first pass at generating the rooms part of the sync response.
It determines which rooms have changed during the sync period, and categorises
them into four buckets: "knock", "invite", "join" and "leave".
If there are too many events in that range to report, things get complicated.
In this situation we return a truncated list of the most recent events, and
indicate in the response that there is a "gap" of omitted events. Additionally:
1. Finds all membership changes for the user in the sync period (from
`since_token` up to `now_token`).
2. Uses those to place the room in one of the four categories above.
3. Builds a `_RoomChanges` struct to record this, and return that struct.
- we include a "state_delta", to describe the changes in state over the gap,
- we include all membership events applying to the user making the request,
even those in the gap.
See the spec for the rationale:
https://spec.matrix.org/v1.1/client-server-api/#syncing
For rooms classified as "knock", "invite" or "leave", we just need to report
a single membership event in the eventual /sync response. For "join" we need
to fetch additional non-membership events, e.g. messages in the room. That is
more complicated, so instead we report an intermediary `RoomSyncResultBuilder`
struct, and leave the additional work to `_generate_room_entry`.
The sync_result_builder is not modified by this function.
"""
@ -1685,16 +1686,6 @@ class SyncHandler:
assert since_token
# The spec
# https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
# notes that membership events need special consideration:
#
# > When a sync is limited, the server MUST return membership events for events
# > in the gap (between since and the start of the returned timeline), regardless
# > as to whether or not they are redundant.
#
# We fetch such events here, but we only seem to use them for categorising rooms
# as newly joined, newly left, invited or knocked.
# TODO: we've already called this function and ran this query in
# _have_rooms_changed. We could keep the results in memory to avoid a
# second query, at the cost of more complicated source code.
@ -2008,6 +1999,23 @@ class SyncHandler:
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
Ideally, we want to report all events whose stream ordering `s` lies in the
range `since_token < s <= now_token`, where the two tokens are read from the
sync_result_builder.
If there are too many events in that range to report, things get complicated.
In this situation we return a truncated list of the most recent events, and
indicate in the response that there is a "gap" of omitted events. Lots of this
is handled in `_load_filtered_recents`, but some of is handled in this method.
Additionally:
- we include a "state_delta", to describe the changes in state over the gap,
- we include all membership events applying to the user making the request,
even those in the gap.
See the spec for the rationale:
https://spec.matrix.org/v1.1/client-server-api/#syncing
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
@ -2037,7 +2045,7 @@ class SyncHandler:
since_token = room_builder.since_token
upto_token = room_builder.upto_token
with start_active_span("generate_room_entry"):
with start_active_span("sync.generate_room_entry"):
set_tag("room_id", room_id)
log_kv({"events": len(events or ())})
@ -2165,10 +2173,10 @@ class SyncHandler:
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
unread_notifications["notification_count"] = notifs.notify_count
unread_notifications["highlight_count"] = notifs.highlight_count
room_sync.unread_count = notifs["unread_count"]
room_sync.unread_count = notifs.unread_count
sync_result_builder.joined.append(room_sync)

View File

@ -13,9 +13,10 @@
# limitations under the License.
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import (
@ -37,7 +38,10 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomMember:
room_id: str
user_id: str
# How often we expect remote servers to resend us presence.
@ -119,7 +123,7 @@ class FollowerTypingHandler:
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
return member.user_id in self._room_typing.get(member.room_id, set())
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
@ -166,9 +170,9 @@ class FollowerTypingHandler:
for row in rows:
self._room_serials[row.room_id] = token
prev_typing = set(self._room_typing.get(row.room_id, []))
prev_typing = self._room_typing.get(row.room_id, set())
now_typing = set(row.user_ids)
self._room_typing[row.room_id] = row.user_ids
self._room_typing[row.room_id] = now_typing
if self.federation:
run_as_background_process(

View File

@ -148,9 +148,21 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
# If still None then the initial background update hasn't happened yet.
if self.pos is None:
return None
# If still None then the initial background update hasn't happened yet.
if self.pos is None:
return None
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos > room_max_stream_ordering:
# apparently, we've processed more events than exist in the database!
# this can happen if events are removed with history purge or similar.
logger.warning(
"Event stream ordering appears to have gone backwards (%i -> %i): "
"rewinding user directory processor",
self.pos,
room_max_stream_ordering,
)
self.pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date
while True:

View File

@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(504, msg)
@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
def redact_uri(uri):
def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with <redacted>"""
uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528
"""
def stopProducing(self):
def stopProducing(self) -> None:
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
def __init__(self, hs: "HomeServer", handler):
def __init__(
self,
hs: "HomeServer",
handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
super().__init__()
self._handler = handler
def _async_render(self, request: Request):
async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)
return await self._handler(request)

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from http import HTTPStatus
from io import BytesIO
from typing import (
TYPE_CHECKING,
@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
e = SynapseError(
HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
)
return defer.fail(Failure(e))
return self._agent.request(
@ -586,7 +589,7 @@ class SimpleHttpClient:
if headers:
actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
body = await self.get_raw(uri, args, headers=actual_headers)
return json_decoder.decode(body.decode("utf-8"))
async def put_json(
@ -720,7 +723,9 @@ class SimpleHttpClient:
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
raise SynapseError(
HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
@ -732,12 +737,14 @@ class SimpleHttpClient:
)
except BodyExceededMaxSize:
raise SynapseError(
502,
HTTPStatus.BAD_GATEWAY,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
except Exception as e:
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
raise SynapseError(
HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
) from e
return (
length,

View File

@ -25,6 +25,7 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import (
IProtocol,
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
def connect(
self, protocol_factory: IProtocolFactory
) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
first_exception = None
server_list = await self._resolve_server()

View File

@ -19,6 +19,7 @@ import random
import sys
import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
request.destination,
msg,
)
raise SynapseError(502, msg, Codes.TOO_LARGE)
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",

View File

@ -14,7 +14,6 @@
# limitations under the License.
import abc
import collections
import html
import logging
import types
@ -30,12 +29,14 @@ from typing import (
Iterable,
Iterator,
List,
NoReturn,
Optional,
Pattern,
Tuple,
Union,
)
import attr
import jinja2
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
@ -57,12 +58,14 @@ from synapse.api.errors import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import opentracing
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -170,7 +173,9 @@ def return_html_error(
respond_with_html(request, code, body)
def wrap_async_request_handler(h):
def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
@ -183,7 +188,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""
async def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest
) -> None:
with request.processing():
await h(self, request)
@ -240,18 +247,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling.
"""
def __init__(self, extract_context=False):
def __init__(self, extract_context: bool = False):
super().__init__()
self._extract_context = extract_context
def render(self, request):
def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest):
async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@ -271,7 +278,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
f = failure.Failure()
self._send_error_response(f, request)
async def _async_render(self, request: Request):
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@ -318,7 +325,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
def __init__(self, canonical_json=False, extract_context=False):
def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context)
self.canonical_json = canonical_json
@ -327,7 +334,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
@ -347,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource):
return_json_error(f, request)
_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PathEntry:
pattern: Pattern
callback: ServletCallback
servlet_classname: str
class JsonResource(DirectServeJsonResource):
@ -368,34 +377,45 @@ class JsonResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
def __init__(
self,
hs: "HomeServer",
canonical_json: bool = True,
extract_context: bool = False,
):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs
def register_paths(self, method, path_patterns, callback, servlet_classname):
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.
Args:
method (str): GET, POST etc
method: GET, POST etc
path_patterns (Iterable[str]): A list of regular expressions to which
the request URLs are compared.
path_patterns: A list of regular expressions to which the request
URLs are compared.
callback (function): The handler for the request. Usually a Servlet
callback: The handler for the request. Usually a Servlet
servlet_classname (str): The name of the handler to be used in prometheus
servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
method = method.encode("utf-8") # method is bytes on py3
method_bytes = method.encode("utf-8")
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname)
)
@ -427,7 +447,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
async def _async_render(self, request):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appropriate name for this handler in prometheus
@ -468,7 +488,7 @@ class DirectServeHtmlResource(_AsyncResource):
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
@ -492,12 +512,12 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection.
"""
def render_GET(self, request: Request):
def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request)
return super().render_GET(request)
def _unrecognised_request_handler(request):
def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests
This is a request handler suitable for return from
@ -505,7 +525,7 @@ def _unrecognised_request_handler(request):
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
request: Unused, but passed in to match the signature of ServletCallback.
"""
raise UnrecognizedRequestError()
@ -513,23 +533,23 @@ def _unrecognised_request_handler(request):
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
def __init__(self, path):
resource.Resource.__init__(self)
def __init__(self, path: str):
super().__init__()
self.url = path
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request):
def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0:
return self # select ourselves as the child to render
return resource.Resource.getChild(self, name, request)
return super().getChild(name, request)
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@ -537,10 +557,10 @@ class OptionsResource(resource.Resource):
return b""
def getChildWithDefault(self, path, request):
def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
return resource.Resource.getChildWithDefault(self, path, request)
return super().getChildWithDefault(path, request)
class RootOptionsRedirectResource(OptionsResource, RootRedirect):
@ -649,7 +669,7 @@ def respond_with_json(
json_object: Any,
send_cors: bool = False,
canonical_json: bool = True,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@ -696,7 +716,7 @@ def respond_with_json_bytes(
code: int,
json_bytes: bytes,
send_cors: bool = False,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@ -713,7 +733,7 @@ def respond_with_json_bytes(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@ -731,7 +751,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
) -> None:
"""Encodes the given JSON object on a thread and then writes it to the
request.
@ -743,7 +763,20 @@ async def _async_write_json_to_request_in_thread(
expensive.
"""
json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes:
# it might take a while for the threadpool to schedule us, so we write
# opentracing logs once we actually get scheduled, so that we can see how
# much that contributed.
if opentracing_span:
opentracing_span.log_kv({"event": "scheduled"})
res = json_encoder(json_object)
if opentracing_span:
opentracing_span.log_kv({"event": "encoded"})
return res
with start_active_span("encode_json_response"):
span = active_span()
json_str = await defer_to_thread(request.reactor, encode, span)
_write_bytes_to_request(request, json_str)
@ -773,7 +806,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
def set_cors_headers(request: Request):
def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@ -790,14 +823,14 @@ def set_cors_headers(request: Request):
)
def respond_with_html(request: Request, code: int, html: str):
def respond_with_html(request: Request, code: int, html: str) -> None:
"""
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
"""
respond_with_html_bytes(request, code, html.encode("utf-8"))
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
"""
Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
@ -815,7 +848,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@ -828,7 +861,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request)
def set_clickjacking_protection_headers(request: Request):
def set_clickjacking_protection_headers(request: Request) -> None:
"""
Set headers to guard against clickjacking of embedded content.
@ -850,7 +883,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request)
def finish_request(request: Request):
def finish_request(request: Request) -> None:
"""Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the

View File

@ -14,6 +14,7 @@
""" This module contains base REST classes for constructing REST servlets. """
import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
@ -30,6 +31,7 @@ from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder
@ -137,11 +139,15 @@ def parse_integer_from_args(
return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
)
else:
return default
@ -246,11 +252,15 @@ def parse_boolean_from_args(
message = (
"Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
)
else:
return default
@ -313,7 +323,7 @@ def parse_bytes_from_args(
return args[name_bytes][0]
elif required:
message = "Missing string query parameter %s" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
return default
@ -407,14 +417,16 @@ def _parse_string_value(
try:
value_str = value.decode(encoding)
except ValueError:
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
)
if allowed_values is not None and value_str not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
else:
return value_str
@ -510,7 +522,9 @@ def parse_strings_from_args(
else:
if required:
message = "Missing string query parameter %r" % (name,)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
)
return default
@ -638,7 +652,7 @@ def parse_json_value_from_request(
try:
content_bytes = request.content.read() # type: ignore
except Exception:
raise SynapseError(400, "Error reading JSON content.")
raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
@ -647,7 +661,9 @@ def parse_json_value_from_request(
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
)
return content
@ -673,7 +689,7 @@ def parse_json_object_from_request(
if not isinstance(content, dict):
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
return content
@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
)
class RestServlet:
@ -709,7 +727,7 @@ class RestServlet:
into the appropriate HTTP response.
"""
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
@ -758,10 +776,12 @@ class ResolveRoomIdMixin:
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
HTTPStatus.BAD_REQUEST,
"%s was not legal room ID or room alias" % (room_identifier,),
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
HTTPStatus.BAD_REQUEST,
"Unknown room ID or room alias %s" % room_identifier,
)
return resolved_room_id, remote_room_hosts

View File

@ -14,7 +14,7 @@
import contextlib
import logging
import time
from typing import Generator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
@ -35,6 +35,9 @@ from synapse.logging.context import (
)
from synapse.types import Requester
if TYPE_CHECKING:
import opentracing
logger = logging.getLogger(__name__)
_next_request_seq = 0
@ -66,9 +69,9 @@ class SynapseRequest(Request):
self,
channel: HTTPChannel,
site: "SynapseSite",
*args,
*args: Any,
max_request_body_size: int = 1024,
**kw,
**kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
@ -81,6 +84,10 @@ class SynapseRequest(Request):
# server name, for client requests this is the Requester object.
self._requester: Optional[Union[Requester, str]] = None
# An opentracing span for this request. Will be closed when the request is
# completely processed.
self._opentracing_span: "Optional[opentracing.Span]" = None
# we can't yet create the logcontext, as we don't know the method.
self.logcontext: Optional[LoggingContext] = None
@ -148,6 +155,13 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def set_opentracing_span(self, span: "opentracing.Span") -> None:
"""attach an opentracing span to this request
Doing so will cause the span to be closed when we finish processing the request
"""
self._opentracing_span = span
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
@ -286,6 +300,9 @@ class SynapseRequest(Request):
self._processing_finished_time = time.time()
self._is_processing = False
if self._opentracing_span:
self._opentracing_span.log_kv({"event": "finished processing"})
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
@ -299,6 +316,8 @@ class SynapseRequest(Request):
"""
self.finish_time = time.time()
Request.finish(self)
if self._opentracing_span:
self._opentracing_span.log_kv({"event": "response sent"})
if not self._is_processing:
assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
@ -333,6 +352,11 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
logger.info("Connection from client lost before response was sent")
if self._opentracing_span:
self._opentracing_span.log_kv(
{"event": "client connection lost", "reason": str(reason.value)}
)
if not self._is_processing:
self._finished_processing()
@ -421,6 +445,10 @@ class SynapseRequest(Request):
usage.evt_db_fetch_count,
)
# complete the opentracing span, if any.
if self._opentracing_span:
self._opentracing_span.finish()
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
@ -557,7 +585,7 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued: bool) -> Request:
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,

View File

@ -22,20 +22,33 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
import inspect
import logging
import threading
import typing
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@ -55,7 +68,6 @@ try:
def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return resource.getrusage(RUSAGE_THREAD)
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage.
@ -66,7 +78,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
def logcontext_error(msg: str):
def logcontext_error(msg: str) -> None:
logger.warning(msg)
@ -223,22 +235,19 @@ class _Sentinel:
def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
def start(self, rusage: "Optional[resource.struct_rusage]"):
def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
def stop(self, rusage: "Optional[resource.struct_rusage]"):
def add_database_transaction(self, duration_sec: float) -> None:
pass
def add_database_transaction(self, duration_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
pass
def __bool__(self) -> Literal[False]:
@ -379,7 +388,12 @@ class LoggingContext:
)
return self
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@ -399,17 +413,6 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
# we track the current request
record.request = self.request
# we also track the current scope:
record.scope = self.scope
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@ -626,7 +629,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@ -711,16 +719,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
def preserve_fn(f):
R = TypeVar("R")
@overload
def preserve_fn( # type: ignore[misc]
f: Callable[..., Awaitable[R]],
) -> Callable[..., "defer.Deferred[R]"]:
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
...
def preserve_fn(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
]
) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs)
return g
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
@overload
def run_in_background( # type: ignore[misc]
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def run_in_background(
f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
...
def run_in_background(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@ -751,6 +804,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)
return defer.succeed(res)
if res.called and not res.paused:
@ -778,13 +835,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res
def make_deferred_yieldable(deferred):
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
T = TypeVar("T")
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
result/failure).
def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Given a deferred, make it follow the Synapse logcontext rules:
If the deferred has completed, essentially does nothing (just returns another
completed deferred with the result/failure).
If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
@ -792,16 +850,6 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes, it it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred):
return deferred
if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery.
@ -823,7 +871,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
def defer_to_thread(reactor, f, *args, **kwargs):
def defer_to_thread(
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@ -855,7 +905,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
def defer_to_threadpool(
reactor: "ISynapseReactor",
threadpool: ThreadPool,
f: Callable[..., R],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@ -897,7 +953,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
def g():
def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)

View File

@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ
import attr
from twisted.internet import defer
from twisted.web.http import Request
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
@ -219,11 +220,12 @@ class _DummyTagNames:
try:
import opentracing
import opentracing.tags
tags = opentracing.tags
except ImportError:
opentracing = None
tags = _DummyTagNames
opentracing = None # type: ignore[assignment]
tags = _DummyTagNames # type: ignore[assignment]
try:
from jaeger_client import Config as JaegerConfig
@ -366,7 +368,7 @@ def init_tracer(hs: "HomeServer"):
global opentracing
if not hs.config.tracing.opentracer_enabled:
# We don't have a tracer
opentracing = None
opentracing = None # type: ignore[assignment]
return
if not opentracing or not JaegerConfig:
@ -452,7 +454,7 @@ def start_active_span(
"""
if opentracing is None:
return noop_context_manager()
return noop_context_manager() # type: ignore[unreachable]
return opentracing.tracer.start_active_span(
operation_name,
@ -477,7 +479,7 @@ def start_active_span_follows_from(
forced, the new span will also have tracing forced.
"""
if opentracing is None:
return noop_context_manager()
return noop_context_manager() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(operation_name, references=references)
@ -490,48 +492,6 @@ def start_active_span_follows_from(
return scope
def start_active_span_from_request(
request,
operation_name,
references=None,
tags=None,
start_time=None,
ignore_active_span=False,
finish_on_close=True,
):
"""
Extracts a span context from a Twisted Request.
args:
headers (twisted.web.http.Request)
For the other args see opentracing.tracer
returns:
span_context (opentracing.span.SpanContext)
"""
# Twisted encodes the values as lists whereas opentracing doesn't.
# So, we take the first item in the list.
# Also, twisted uses byte arrays while opentracing expects strings.
if opentracing is None:
return noop_context_manager()
header_dict = {
k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
}
context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
return opentracing.tracer.start_active_span(
operation_name,
child_of=context,
references=references,
tags=tags,
start_time=start_time,
ignore_active_span=ignore_active_span,
finish_on_close=finish_on_close,
)
def start_active_span_from_edu(
edu_content,
operation_name,
@ -553,7 +513,7 @@ def start_active_span_from_edu(
references = references or []
if opentracing is None:
return noop_context_manager()
return noop_context_manager() # type: ignore[unreachable]
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
@ -594,18 +554,21 @@ def active_span():
@ensure_active_span("set a tag")
def set_tag(key, value):
"""Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
def log_kv(key_values, timestamp=None):
"""Log to the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
def set_operation_name(operation_name):
"""Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@ -674,6 +637,7 @@ def inject_header_dict(
span = opentracing.tracer.active_span
carrier: Dict[str, str] = {}
assert span is not None
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@ -716,6 +680,7 @@ def get_active_span_text_map(destination=None):
return {}
carrier: Dict[str, str] = {}
assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@ -731,12 +696,27 @@ def active_span_context_as_string():
"""
carrier: Dict[str, str] = {}
if opentracing:
assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
return json_encoder.encode(carrier)
def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
"""Extract an opentracing context from the headers on an HTTP request
This is useful when we have received an HTTP request from another part of our
system, and want to link our spans to those of the remote system.
"""
if not opentracing:
return None
header_dict = {
k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
}
return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
@only_if_tracing
def span_context_from_string(carrier):
"""
@ -773,7 +753,7 @@ def trace(func=None, opname=None):
def decorator(func):
if opentracing is None:
return func
return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
@ -864,7 +844,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
"""
if opentracing is None:
yield
yield # type: ignore[unreachable]
return
request_tags = {
@ -876,10 +856,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
}
request_name = request.request_metrics.name
if extract_context:
scope = start_active_span_from_request(request, request_name)
else:
scope = start_active_span(request_name)
context = span_context_from_request(request) if extract_context else None
# we configure the scope not to finish the span immediately on exit, and instead
# pass the span into the SynapseRequest, which will finish it once we've finished
# sending the response to the client.
scope = start_active_span(request_name, child_of=context, finish_on_close=False)
request.set_opentracing_span(scope.span)
with scope:
inject_response_headers(request.responseHeaders)

View File

@ -71,7 +71,7 @@ class LogContextScopeManager(ScopeManager):
if not ctx:
# We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)
return Scope(None, span) # type: ignore[arg-type]
elif ctx.scope is not None:
# We want the logging scope to look exactly the same so we give it
# a blank suffix

View File

@ -13,7 +13,6 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import (
Awaitable,
Callable,
@ -44,7 +43,13 @@ from synapse.logging.opentracing import log_kv, start_active_span
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import PersistedEventPosition, RoomStreamToken, StreamToken, UserID
from synapse.types import (
JsonDict,
PersistedEventPosition,
RoomStreamToken,
StreamToken,
UserID,
)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@ -178,7 +183,12 @@ class _NotifierUserStream:
return _NotificationListener(self.notify_deferred.observe())
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventStreamResult:
events: List[Union[JsonDict, EventBase]]
start_token: StreamToken
end_token: StreamToken
def __bool__(self):
return bool(self.events)
@ -582,9 +592,12 @@ class Notifier:
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if after_token == before_token:
return EventStreamResult([], (from_token, from_token))
return EventStreamResult([], from_token, from_token)
events: List[EventBase] = []
# The events fetched from each source are a JsonDict, EventBase, or
# UserPresenceState, but see below for UserPresenceState being
# converted to JsonDict.
events: List[Union[JsonDict, EventBase]] = []
end_token = from_token
for name, source in self.event_sources.sources.get_sources():
@ -623,7 +636,7 @@ class Notifier:
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
return EventStreamResult(events, (from_token, end_token))
return EventStreamResult(events, from_token, end_token)
user_id_for_stream = user.to_string()
if is_peeking:

View File

@ -177,12 +177,12 @@ class EmailPusher(Pusher):
return
for push_action in unprocessed:
received_at = push_action["received_ts"]
received_at = push_action.received_ts
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
room_ready_at = self.room_ready_to_notify_at(push_action.room_id)
should_notify_at = max(notif_ready_at, room_ready_at)
@ -193,23 +193,23 @@ class EmailPusher(Pusher):
# to be delivered.
reason: EmailReason = {
"room_id": push_action["room_id"],
"room_id": push_action.room_id,
"now": self.clock.time_msec(),
"received_at": received_at,
"delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
"last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
"throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
"last_sent_ts": self.get_room_last_sent_ts(push_action.room_id),
"throttle_ms": self.get_room_throttle_ms(push_action.room_id),
}
await self.send_notification(unprocessed, reason)
await self.save_last_stream_ordering_and_success(
max(ea["stream_ordering"] for ea in unprocessed)
max(ea.stream_ordering for ea in unprocessed)
)
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
await self.sent_notif_update_throttle(ea["room_id"], ea)
await self.sent_notif_update_throttle(ea.room_id, ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@ -284,10 +284,10 @@ class EmailPusher(Pusher):
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
notified_push_action["stream_ordering"]
notified_push_action.stream_ordering
)
time_of_this_notifs = notified_push_action["received_ts"]
time_of_this_notifs = notified_push_action.received_ts
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs

View File

@ -192,7 +192,7 @@ class HttpPusher(Pusher):
"http-push",
tags={
"authenticated_entity": self.user_id,
"event_id": push_action["event_id"],
"event_id": push_action.event_id,
"app_id": self.app_id,
"app_display_name": self.app_display_name,
},
@ -202,7 +202,7 @@ class HttpPusher(Pusher):
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
self.last_stream_ordering = push_action.stream_ordering
pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
@ -245,7 +245,7 @@ class HttpPusher(Pusher):
self.pushkey,
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
self.last_stream_ordering = push_action.stream_ordering
await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
@ -268,17 +268,17 @@ class HttpPusher(Pusher):
break
async def _process_one(self, push_action: HttpPushAction) -> bool:
if "notify" not in push_action["actions"]:
if "notify" not in push_action.actions:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
event = await self.store.get_event(push_action["event_id"], allow_none=True)
event = await self.store.get_event(push_action.event_id, allow_none=True)
if event is None:
return True # It's been redacted
rejected = await self.dispatch_push(event, tweaks, badge)

View File

@ -232,15 +232,13 @@ class Mailer:
reason: The notification that was ready and is the cause of an email
being sent.
"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions])
notif_events = await self.store.get_events(
[pa["event_id"] for pa in push_actions]
)
notif_events = await self.store.get_events([pa.event_id for pa in push_actions])
notifs_by_room: Dict[str, List[EmailPushAction]] = {}
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
notifs_by_room.setdefault(pa.room_id, []).append(pa)
# collect the current state for all the rooms in which we have
# notifications
@ -264,7 +262,7 @@ class Mailer:
await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0))
rooms: List[RoomVars] = []
@ -356,7 +354,7 @@ class Mailer:
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
ev = notif_events[n["event_id"]]
ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE:
is_invite = True
@ -376,7 +374,7 @@ class Mailer:
if not is_invite:
for n in notifs:
notifvars = await self._get_notif_vars(
n, user_id, notif_events[n["event_id"]], room_state_ids
n, user_id, notif_events[n.event_id], room_state_ids
)
# merge overlapping notifs together.
@ -444,15 +442,15 @@ class Mailer:
"""
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
notif.room_id,
notif.event_id,
before_limit=CONTEXT_BEFORE,
after_limit=CONTEXT_AFTER,
)
ret: NotifVars = {
"link": self._make_notif_link(notif),
"ts": notif["received_ts"],
"ts": notif.received_ts,
"messages": [],
}
@ -516,7 +514,7 @@ class Mailer:
ret: MessageVars = {
"event_type": event.type,
"is_historical": event.event_id != notif["event_id"],
"is_historical": event.event_id != notif.event_id,
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
@ -610,7 +608,7 @@ class Mailer:
# See if one of the notifs is an invite event for the user
invite_event = None
for n in notifs:
ev = notif_events[n["event_id"]]
ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE:
invite_event = ev
@ -659,7 +657,7 @@ class Mailer:
if len(notifs) == 1:
# There is just the one notification, so give some detail
sender_name = None
event = notif_events[notifs[0]["event_id"]]
event = notif_events[notifs[0].event_id]
if ("m.room.member", event.sender) in room_state_ids:
state_event_id = room_state_ids[("m.room.member", event.sender)]
state_event = await self.store.get_event(state_event_id)
@ -753,9 +751,9 @@ class Mailer:
# are already in descending received_ts.
sender_ids = {}
for n in notifs:
sender = notif_events[n["event_id"]].sender
sender = notif_events[n.event_id].sender
if sender not in sender_ids:
sender_ids[sender] = n["event_id"]
sender_ids[sender] = n.event_id
# Get the actual member events (in order to calculate a pretty name for
# the room).
@ -830,17 +828,17 @@ class Mailer:
if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email.email_riot_base_url,
notif["room_id"],
notif["event_id"],
notif.room_id,
notif.event_id,
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
notif["room_id"],
notif["event_id"],
notif.room_id,
notif.event_id,
)
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id)
def _make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str

View File

@ -17,9 +17,10 @@ import logging
import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from matrix_common.regex import glob_to_regex, to_word_pattern
from synapse.events import EventBase
from synapse.types import JsonDict, UserID
from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent:
r = regex_cache.get((display_name, False, True), None)
if not r:
r1 = re.escape(display_name)
r1 = re_word_boundary(r1)
r1 = to_word_pattern(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
try:
r = regex_cache.get((glob, True, word_boundary), None)
if not r:
r = glob_to_regex(glob, word_boundary)
r = glob_to_regex(glob, word_boundary=word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return bool(r.search(value))
except re.error:

View File

@ -13,6 +13,7 @@
# limitations under the License.
from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
badge = len(invites)
@ -36,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
room_id, user_id, last_unread_event_id
)
)
if notifs["notify_count"] == 0:
if notifs.notify_count == 0:
continue
if group_by_room:
@ -44,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
badge += notifs["notify_count"]
badge += notifs.notify_count
return badge

View File

@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory
from synapse.replication.http.push import ReplicationRemovePusherRestServlet
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -113,7 +114,9 @@ class PusherPool:
"""
if kind == "email":
email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
email_owner = await self.store.get_user_id_by_threepid(
"email", canonicalise_email(pushkey)
)
if email_owner != user_id:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)

View File

@ -88,6 +88,7 @@ REQUIREMENTS = [
# with the latest security patches.
"cryptography>=3.4.7",
"ijson>=3.1",
"matrix-common==1.0.0",
]
CONDITIONAL_REQUIREMENTS = {

View File

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Optional
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[

View File

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache
@ -25,7 +25,12 @@ if TYPE_CHECKING:
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache(

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -27,7 +27,12 @@ if TYPE_CHECKING:
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.hs = hs

View File

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
@ -58,7 +58,12 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@ -75,12 +80,3 @@ class SlavedEventStore(
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token()
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()

View File

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
@ -24,7 +24,12 @@ if TYPE_CHECKING:
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -26,7 +26,12 @@ if TYPE_CHECKING:
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.hs = hs

View File

@ -15,7 +15,6 @@
import heapq
import logging
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
@ -30,6 +29,7 @@ from typing import (
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -226,17 +226,14 @@ class BackfillStream(Stream):
or it went from being an outlier to not.
"""
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class BackfillStreamRow:
event_id: str
room_id: str
type: str
state_key: Optional[str]
redacts: Optional[str]
relates_to: Optional[str]
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
@ -256,18 +253,15 @@ class BackfillStream(Stream):
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: str
currently_active: bool
NAME = "presence"
ROW_TYPE = PresenceStreamRow
@ -302,7 +296,7 @@ class PresenceFederationStream(Stream):
send.
"""
@attr.s(slots=True, auto_attribs=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceFederationStreamRow:
destination: str
user_id: str
@ -320,9 +314,10 @@ class PresenceFederationStream(Stream):
class TypingStream(Stream):
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TypingStreamRow:
room_id: str
user_ids: List[str]
NAME = "typing"
ROW_TYPE = TypingStreamRow
@ -348,16 +343,13 @@ class TypingStream(Stream):
class ReceiptsStream(Stream):
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
receipt_type: str
user_id: str
event_id: str
data: dict
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@ -374,7 +366,9 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
"""A user has changed their push rules"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PushRulesStreamRow:
user_id: str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@ -396,10 +390,12 @@ class PushRulesStream(Stream):
class PushersStream(Stream):
"""A user has added/changed/removed a pusher"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PushersStreamRow:
user_id: str
app_id: str
pushkey: str
deleted: bool
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@ -419,7 +415,7 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s(slots=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
@ -430,9 +426,9 @@ class CachesStream(Stream):
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
cache_func: str
keys: Optional[List[Any]]
invalidation_ts: int
NAME = "caches"
ROW_TYPE = CachesStreamRow
@ -451,9 +447,9 @@ class DeviceListsStream(Stream):
told about a device update.
"""
@attr.s(slots=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity = attr.ib(type=str)
entity: str
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
@ -470,7 +466,9 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ToDeviceStreamRow:
entity: str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@ -487,9 +485,11 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TagAccountDataStreamRow:
user_id: str
room_id: str
data: JsonDict
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@ -506,10 +506,11 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
AccountDataStreamRow = namedtuple(
"AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataStreamRow:
user_id: str
room_id: Optional[str]
data_type: str
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@ -573,10 +574,12 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class GroupsStreamRow:
group_id: str
user_id: str
type: str
content: JsonDict
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@ -593,7 +596,9 @@ class GroupServerStream(Stream):
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserSignatureStreamRow:
user_id: str
NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow

View File

@ -12,14 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
import attr
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -30,13 +32,10 @@ class FederationStream(Stream):
sending disabled.
"""
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FederationStreamRow:
type: str # the type of data as defined in the BaseFederationRows
data: JsonDict # serialization of a federation.send_queue.BaseFederationRow
NAME = "federation"
ROW_TYPE = FederationStreamRow

View File

@ -69,6 +69,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
from synapse.rest.admin.users import (
AccountDataRestServlet,
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
@ -108,7 +109,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
)
def __init__(self, hs: "HomeServer"):
@ -195,7 +196,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
AccountDataRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)

View File

@ -22,7 +22,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates"""
PATTERNS = admin_patterns("/background_updates/start_job")
PATTERNS = admin_patterns("/background_updates/start_job$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"])

View File

@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str, device_id: str
@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
target_user.to_string(), device_id
)
if device is None:
raise NotFoundError("No device found")
return HTTPStatus.OK, device
async def on_DELETE(
@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"):
"""
Args:
hs: server
"""
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, user_id: str
@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())

View File

@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()

View File

@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
200 OK with details of a destination if success otherwise an error.
"""
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups"""
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()

View File

@ -17,7 +17,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = [
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
*admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
# This path kept around for legacy reasons
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
*admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
]
def __init__(self, hs: "HomeServer"):
@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
"/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
"/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined."""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info("Protecting local media by ID: %s", media_id)
@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
class UnprotectMediaByID(RestServlet):
"""Unprotect local media from being quarantined."""
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info("Unprotecting local media by ID: %s", media_id)
@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room."""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
await assert_requester_is_admin(self.auth, request)
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet):
"""Delete local media by a given ID. Removes it from this server."""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
media that exist given for this user
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")

Some files were not shown because too many files have changed in this diff Show More