Merge remote-tracking branch 'upstream/release-v1.30.0'

Commit 8b753230af. 102 changed files with 2693 additions and 933 deletions.
.git-blame-ignore-revs (new file, 8 lines)

@@ -0,0 +1,8 @@
+# Black reformatting (#5482).
+32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
+
+# Target Python 3.5 with black (#8664).
+aff1eb7c671b0a3813407321d2702ec46c71fa56
+
+# Update black to 20.8b1 (#9381).
+0a00b7ff14890987f09112a2ae696c61001e6cf1
CHANGES.md (77 lines changed)

@@ -1,3 +1,80 @@
+Synapse 1.30.0rc1 (2021-03-16)
+==============================
+
+Note that this release deprecates the ability for appservices to
+call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice
+developers should use a `type` value of `m.login.application_service` as
+per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions).
+In future releases, calling this endpoint with an access token - but without a `m.login.application_service`
+type - will fail.
+
+
+Features
+--------
+
+- Add prometheus metrics for number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573))
+- Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549))
+- Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601))
+- Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617))
+- Tell spam checker modules about the SSO IdP a user registered through if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626))
+
+
+Bugfixes
+--------
+
+- Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473))
+- Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583))
+- Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567))
+- Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587))
+- Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597))
+- Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620))
+- Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623))
+
+
+Updates to the Docker image
+---------------------------
+
+- Make use of an improved malloc implementation (`jemalloc`) in the docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553))
+
+
+Improved Documentation
+----------------------
+
+- Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508))
+- Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550))
+- Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571))
+- Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580))
+- Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+
+
+Deprecations and Removals
+-------------------------
+
+- The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540))
+- Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559))
+
+
+Internal Changes
+----------------
+
+- Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458))
+- Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520))
+- Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523))
+- Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618))
+- Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541))
+- Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560))
+- Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561))
+- Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562))
+- Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563))
+- Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568))
+- Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576))
+- Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586))
+- Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590))
+- Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596))
+- Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604))
+- Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619))
+
+
 Synapse 1.29.0 (2021-03-08)
 ===========================
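The deprecation note in the changelog above concerns how application services register users. As a rough sketch only (the endpoint and `type` value come from the note and the linked spec; the homeserver URL, token, and helper name are placeholders), a compliant registration request could look like this in Python:

```python
# Hedged sketch of an appservice registration that includes the now-required
# "type" field. Homeserver URL, as_token and localpart are placeholders.
import requests

def register_appservice_user(homeserver: str, as_token: str, localpart: str) -> dict:
    resp = requests.post(
        f"{homeserver}/_matrix/client/r0/register",
        headers={"Authorization": f"Bearer {as_token}"},
        json={
            "type": "m.login.application_service",  # required going forward
            "username": localpart,
        },
    )
    resp.raise_for_status()
    return resp.json()  # registration response from the homeserver
```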
@@ -20,9 +20,10 @@ recursive-include scripts *
 recursive-include scripts-dev *
 recursive-include synapse *.pyi
 recursive-include tests *.py
-include tests/http/ca.crt
-include tests/http/ca.key
-include tests/http/server.key
+recursive-include tests *.pem
+recursive-include tests *.p8
+recursive-include tests *.crt
+recursive-include tests *.key

 recursive-include synapse/res *
 recursive-include synapse/static *.css
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
 It is recommended to put a reverse proxy such as
 `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
 `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
-`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
-`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
+`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
+`HAProxy <https://www.haproxy.org/>`_ or
+`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
 doing so is that it means that you can expose the default https port (443) to
 Matrix clients without needing to run Synapse with root privileges.

@@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
 need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
 "ACS location" (also known as "allowed callback URLs") at the identity provider.

+The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
+``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
+provider uses this property to validate or otherwise identify Synapse, its configuration
+will need to be updated to use the new URL. Alternatively you could create a new, separate
+"EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
+the existing "EntityDescriptor" as they were.
+
 Changes to HTML templates
 -------------------------
@@ -69,6 +69,7 @@ RUN apt-get update && apt-get install -y \
     libpq5 \
     libwebp6 \
     xmlsec1 \
+    libjemalloc2 \
     && rm -rf /var/lib/apt/lists/*

 COPY --from=builder /install /usr/local
@@ -204,3 +204,8 @@ healthcheck:
   timeout: 10s
   retries: 3
 ```
+
+## Using jemalloc
+
+Jemalloc is embedded in the image and will be used instead of the default allocator.
+You can read about jemalloc by reading the Synapse [README](../README.md)
@@ -3,6 +3,7 @@
 import codecs
 import glob
 import os
+import platform
 import subprocess
 import sys

@@ -213,6 +214,13 @@ def main(args, environ):
     if "-m" not in args:
         args = ["-m", synapse_worker] + args

+    jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
+
+    if os.path.isfile(jemallocpath):
+        environ["LD_PRELOAD"] = jemallocpath
+    else:
+        log("Could not find %s, will not use" % (jemallocpath,))
+
     # if there are no config files passed to synapse, try adding the default file
     if not any(p.startswith("--config-path") or p.startswith("-c") for p in args):
         config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data")

@@ -248,9 +256,9 @@ running with 'migrate_config'. See the README for more details.
     args = ["python"] + args
     if ownership is not None:
         args = ["gosu", ownership] + args
-        os.execv("/usr/sbin/gosu", args)
+        os.execve("/usr/sbin/gosu", args, environ)
     else:
-        os.execv("/usr/local/bin/python", args)
+        os.execve("/usr/local/bin/python", args, environ)


 if __name__ == "__main__":
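Two details in the entrypoint hunks above are easy to miss: the jemalloc path is derived from `platform.machine()`, and the final exec call was switched from `os.execv` to `os.execve` so that the locally modified `environ` mapping (now carrying `LD_PRELOAD`) is handed to the replacement process explicitly. A standalone illustration of that difference, not part of the diff, with an example library path:

```python
# Illustration only: os.execv() starts the new program with the *current*
# process environment, while os.execve() takes an explicit mapping, so edits
# made to a separate environ dict (such as LD_PRELOAD below) take effect.
import os
import platform

environ = dict(os.environ)
jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),)
if os.path.isfile(jemallocpath):
    environ["LD_PRELOAD"] = jemallocpath

# Never returns on success; argv[0] is conventionally the program name.
os.execve("/usr/local/bin/python", ["python", "-m", "synapse.app.homeserver"], environ)
```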
@@ -1,5 +1,7 @@
 # Contents
-- [List all media in a room](#list-all-media-in-a-room)
+- [Querying media](#querying-media)
+  * [List all media in a room](#list-all-media-in-a-room)
+  * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
 - [Quarantine media](#quarantine-media)
   * [Quarantining media by ID](#quarantining-media-by-id)
   * [Quarantining media in a room](#quarantining-media-in-a-room)

@@ -10,7 +12,11 @@
   * [Delete local media by date or size](#delete-local-media-by-date-or-size)
 - [Purge Remote Media API](#purge-remote-media-api)

-# List all media in a room
+# Querying media
+
+These APIs allow extracting media information from the homeserver.
+
+## List all media in a room

 This API gets a list of known media in a room.
 However, it only shows media from unencrypted events or rooms.

@@ -36,6 +42,12 @@ The API returns a JSON body like the following:
 }
 ```

+## List all media uploaded by a user
+
+Listing all media that has been uploaded by a local user can be achieved through
+the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user)
+Admin API.
+
 # Quarantine media

 Quarantining media means that it is marked as inaccessible by users. It applies
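For the reorganised "Querying media" section above, here is a hedged sketch of calling the room media listing from Python. The endpoint path is assumed from Synapse's usual admin API conventions rather than stated in this hunk, and the homeserver URL, admin token and room ID are placeholders:

```python
# Hedged sketch: list known media in a room via the Synapse admin API.
# The endpoint path is an assumption based on Synapse's admin API layout.
import requests

def list_room_media(homeserver: str, admin_token: str, room_id: str) -> dict:
    resp = requests.get(
        f"{homeserver}/_synapse/admin/v1/room/{room_id}/media",
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    resp.raise_for_status()
    # The response is the JSON body referred to above (lists of known media).
    return resp.json()
```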
@@ -226,7 +226,7 @@ Synapse config:
 oidc_providers:
   - idp_id: github
     idp_name: Github
-    idp_brand: "org.matrix.github"  # optional: styling hint for clients
+    idp_brand: "github"  # optional: styling hint for clients
     discover: false
     issuer: "https://github.com/"
     client_id: "your-client-id" # TO BE FILLED

@@ -252,7 +252,7 @@ oidc_providers:
 oidc_providers:
   - idp_id: google
     idp_name: Google
-    idp_brand: "org.matrix.google"  # optional: styling hint for clients
+    idp_brand: "google"  # optional: styling hint for clients
     issuer: "https://accounts.google.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED

@@ -299,7 +299,7 @@ Synapse config:
 oidc_providers:
   - idp_id: gitlab
     idp_name: Gitlab
-    idp_brand: "org.matrix.gitlab"  # optional: styling hint for clients
+    idp_brand: "gitlab"  # optional: styling hint for clients
     issuer: "https://gitlab.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED

@@ -334,7 +334,7 @@ Synapse config:
 ```yaml
   - idp_id: facebook
     idp_name: Facebook
-    idp_brand: "org.matrix.facebook"  # optional: styling hint for clients
+    idp_brand: "facebook"  # optional: styling hint for clients
     discover: false
     issuer: "https://facebook.com"
     client_id: "your-client-id" # TO BE FILLED

@@ -401,8 +401,7 @@ oidc_providers:
     idp_name: "XWiki"
     issuer: "https://myxwikihost/xwiki/oidc/"
     client_id: "your-client-id" # TO BE FILLED
-    # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
-    client_secret: "dontcare"
+    client_auth_method: none
     scopes: ["openid", "profile"]
     user_profile_method: "userinfo_endpoint"
     user_mapping_provider:

@@ -410,3 +409,40 @@ oidc_providers:
         localpart_template: "{{ user.preferred_username }}"
         display_name_template: "{{ user.name }}"
 ```
+
+## Apple
+
+Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
+
+You will need to create a new "Services ID" for SiWA, and create and download a
+private key with "SiWA" enabled.
+
+As well as the private key file, you will need:
+ * Client ID: the "identifier" you gave the "Services ID"
+ * Team ID: a 10-character ID associated with your developer account.
+ * Key ID: the 10-character identifier for the key.
+
+https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
+documentation on setting up SiWA.
+
+The synapse config will look like this:
+
+```yaml
+  - idp_id: apple
+    idp_name: Apple
+    issuer: "https://appleid.apple.com"
+    client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
+    client_auth_method: "client_secret_post"
+    client_secret_jwt_key:
+      key_file: "/path/to/AuthKey_KEYIDCODE.p8"  # point to your key file
+      jwt_header:
+        alg: ES256
+        kid: "KEYIDCODE"   # Set to the 10-char Key ID
+      jwt_payload:
+        iss: TEAMIDCODE    # Set to the 10-char Team ID
+    scopes: ["name", "email", "openid"]
+    authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+    user_mapping_provider:
+      config:
+        email_template: "{{ user.email }}"
+```
@@ -3,8 +3,9 @@
 It is recommended to put a reverse proxy such as
 [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
 [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
-[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
-[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
+[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
+[HAProxy](https://www.haproxy.org/) or
+[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
 of doing so is that it means that you can expose the default https port
 (443) to Matrix clients without needing to run Synapse with root
 privileges.

@@ -162,6 +163,52 @@ backend matrix
     server matrix 127.0.0.1:8008
 ```

+### Relayd
+
+```
+table <webserver> { 127.0.0.1 }
+table <matrixserver> { 127.0.0.1 }
+
+http protocol "https" {
+    tls { no tlsv1.0, ciphers "HIGH" }
+    tls keypair "example.com"
+    match header set "X-Forwarded-For" value "$REMOTE_ADDR"
+    match header set "X-Forwarded-Proto" value "https"
+
+    # set CORS header for .well-known/matrix/server, .well-known/matrix/client
+    # httpd does not support setting headers, so do it here
+    match request path "/.well-known/matrix/*" tag "matrix-cors"
+    match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
+
+    pass quick path "/_matrix/*" forward to <matrixserver>
+    pass quick path "/_synapse/client/*" forward to <matrixserver>
+
+    # pass on non-matrix traffic to webserver
+    pass forward to <webserver>
+}
+
+relay "https_traffic" {
+    listen on egress port 443 tls
+    protocol "https"
+    forward to <matrixserver> port 8008 check tcp
+    forward to <webserver> port 8080 check tcp
+}
+
+http protocol "matrix" {
+    tls { no tlsv1.0, ciphers "HIGH" }
+    tls keypair "example.com"
+    block
+    pass quick path "/_matrix/*" forward to <matrixserver>
+    pass quick path "/_synapse/client/*" forward to <matrixserver>
+}
+
+relay "matrix_federation" {
+    listen on egress port 8448 tls
+    protocol "matrix"
+    forward to <matrixserver> port 8008 check tcp
+}
+```
+
 ## Homeserver Configuration

 You will also want to set `bind_addresses: ['127.0.0.1']` and
@@ -89,8 +89,7 @@ pid_file: DATADIR/homeserver.pid
 # Whether to require authentication to retrieve profile data (avatars,
 # display names) of other users through the client API. Defaults to
 # 'false'. Note that profile data is also available via the federation
-# API, so this setting is of limited value if federation is enabled on
-# the server.
+# API, unless allow_profile_lookup_over_federation is set to false.
 #
 #require_auth_for_profile_requests: true

@@ -1780,7 +1779,26 @@ saml2_config:
 #
 #       client_id: Required. oauth2 client id to use.
 #
-#       client_secret: Required. oauth2 client secret to use.
+#       client_secret: oauth2 client secret to use. May be omitted if
+#           client_secret_jwt_key is given, or if client_auth_method is 'none'.
+#
+#       client_secret_jwt_key: Alternative to client_secret: details of a key used
+#           to create a JSON Web Token to be used as an OAuth2 client secret. If
+#           given, must be a dictionary with the following properties:
+#
+#           key: a pem-encoded signing key. Must be a suitable key for the
+#               algorithm specified. Required unless 'key_file' is given.
+#
+#           key_file: the path to file containing a pem-encoded signing key file.
+#               Required unless 'key' is given.
+#
+#           jwt_header: a dictionary giving properties to include in the JWT
+#               header. Must include the key 'alg', giving the algorithm used to
+#               sign the JWT, such as "ES256", using the JWA identifiers in
+#               RFC7518.
+#
+#           jwt_payload: an optional dictionary giving properties to include in
+#               the JWT payload. Normally this should include an 'iss' key.
 #
 #       client_auth_method: auth method to use when exchanging the token. Valid
 #           values are 'client_secret_basic' (default), 'client_secret_post' and

@@ -1901,7 +1919,7 @@ oidc_providers:
   #
   #- idp_id: github
   #  idp_name: Github
-  #  idp_brand: org.matrix.github
+  #  idp_brand: github
   #  discover: false
   #  issuer: "https://github.com/"
   #  client_id: "your-client-id" # TO BE FILLED

@@ -2627,19 +2645,20 @@ user_directory:


-# Local statistics collection. Used in populating the room directory.
+# Settings for local room and user statistics collection. See
+# docs/room_and_user_statistics.md.
 #
-# 'bucket_size' controls how large each statistics timeslice is. It can
-# be defined in a human readable short form -- e.g. "1d", "1y".
-#
-# 'retention' controls how long historical statistics will be kept for.
-# It can be defined in a human readable short form -- e.g. "1d", "1y".
-#
-#
-#stats:
-#   enabled: true
-#   bucket_size: 1d
-#   retention: 1y
+stats:
+  # Uncomment the following to disable room and user statistics. Note that doing
+  # so may cause certain features (such as the room directory) not to work
+  # correctly.
+  #
+  #enabled: false
+
+  # The size of each timeslice in the room_stats_historical and
+  # user_stats_historical tables, as a time period. Defaults to "1d".
+  #
+  #bucket_size: 1h


 # Server Notices room configuration
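The `client_secret_jwt_key` documentation in the sample-config hunk above explains that Synapse can mint the OAuth2 client secret as a signed JWT instead of using a static secret. Purely as an illustration of the idea, not Synapse's implementation, and assuming the `authlib` library, an ES256 private key in PEM form, and claim values (`sub`, `aud`, `exp`) that are assumptions layered on top of the documented `jwt_header`/`jwt_payload` options:

```python
# Illustration only, not Synapse's implementation. Assumes authlib and an
# ES256 private key in PEM form; sub/aud/exp are illustrative assumptions.
import time
from authlib.jose import jwt

def make_client_secret(key_pem: str, kid: str, team_id: str, client_id: str) -> str:
    header = {"alg": "ES256", "kid": kid}    # mirrors jwt_header
    now = int(time.time())
    payload = {
        "iss": team_id,                      # jwt_payload normally carries 'iss'
        "sub": client_id,                    # assumed claim
        "aud": "https://appleid.apple.com",  # assumed audience for an Apple-style IdP
        "iat": now,
        "exp": now + 3600,
    }
    # authlib signs the JWT with the supplied PEM key and returns bytes.
    return jwt.encode(header, payload, key_pem).decode("ascii")
```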
@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
 * An instance of `synapse.module_api.ModuleApi`.

 It then implements methods which return a boolean to alter behavior in Synapse.
+All the methods must be defined.

 There's a generic method for checking every event (`check_event_for_spam`), as
 well as some specific methods:

@@ -24,6 +25,7 @@ well as some specific methods:
 * `user_may_publish_room`
 * `check_username_for_spam`
 * `check_registration_for_spam`
+* `check_media_file_for_spam`

 The details of each of these methods (as well as their inputs and outputs)
 are documented in the `synapse.events.spamcheck.SpamChecker` class.

@@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
 The `ModuleApi` class provides a way for the custom spam checker class to
 call back into the homeserver internals.

+Additionally, a `parse_config` method is mandatory and receives the plugin config
+dictionary. After parsing, it must return an object which will be
+passed to `__init__` later.
+
 ### Example

 ```python

@@ -41,6 +47,10 @@ class ExampleSpamChecker:
         self.config = config
         self.api = api

+    @staticmethod
+    def parse_config(config):
+        return config
+
     async def check_event_for_spam(self, foo):
         return False  # allow all events

@@ -59,7 +69,13 @@ class ExampleSpamChecker:
     async def check_username_for_spam(self, user_profile):
         return False  # allow all usernames

-    async def check_registration_for_spam(self, email_threepid, username, request_info):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
         return RegistrationBehaviour.ALLOW  # allow all registrations

     async def check_media_file_for_spam(self, file_wrapper, file_info):
mypy.ini (4 lines changed)

@@ -69,6 +69,7 @@ files =
   synapse/util/async_helpers.py,
   synapse/util/caches,
   synapse/util/metrics.py,
+  synapse/util/macaroons.py,
   synapse/util/stringutils.py,
   tests/replication,
   tests/test_utils,

@@ -116,9 +117,6 @@ ignore_missing_imports = True
 [mypy-saml2.*]
 ignore_missing_imports = True

-[mypy-unpaddedbase64]
-ignore_missing_imports = True
-
 [mypy-canonicaljson]
 ignore_missing_imports = True
@@ -2,9 +2,14 @@
 # Find linting errors in Synapse's default config file.
 # Exits with 0 if there are no problems, or another code otherwise.

+# cd to the root of the repository
+cd `dirname $0`/..
+
+# Restore backup of sample config upon script exit
+trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
+
 # Fix non-lowercase true/false values
 sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
-rm docs/sample_config.yaml.bak

 # Check if anything changed
-git diff --exit-code docs/sample_config.yaml
+diff docs/sample_config.yaml docs/sample_config.yaml.bak
@@ -3,6 +3,7 @@ test_suite = tests

 [check-manifest]
 ignore =
+    .git-blame-ignore-revs
     contrib
     contrib/*
     docs/*
@@ -17,7 +17,9 @@
 """
 from typing import Any, List, Optional, Type, Union

-class RedisProtocol:
+from twisted.internet import protocol
+
+class RedisProtocol(protocol.Protocol):
     def publish(self, channel: str, message: bytes): ...
     async def ping(self) -> None: ...
     async def set(

@@ -52,7 +54,7 @@ def lazyConnection(

 class ConnectionHandler: ...

-class RedisFactory:
+class RedisFactory(protocol.ReconnectingClientFactory):
     continueTrying: bool
     handler: RedisProtocol
     pool: List[RedisProtocol]
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.29.0"
+__version__ = "1.30.0rc1"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)

@@ -163,7 +164,7 @@ class Auth:

     async def get_user_by_req(
         self,
-        request: Request,
+        request: SynapseRequest,
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,

@@ -408,7 +409,7 @@ class Auth:
             raise _InvalidMacaroonException()

         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
+            user_id = get_value_from_macaroon(macaroon, "user_id")

             guest = False
             for caveat in macaroon.caveats:

@@ -416,7 +417,12 @@ class Auth:
                     guest = True

             self.validate_macaroon(macaroon, rights, user_id=user_id)
-        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+        except (
+            pymacaroons.exceptions.MacaroonException,
+            KeyError,
+            TypeError,
+            ValueError,
+        ):
             raise InvalidClientTokenError("Invalid macaroon passed.")

         if rights == "access":

@@ -424,27 +430,6 @@ class Auth:

         return user_id, guest

-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-
-        Does *not* validate the macaroon.
-
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-        Returns:
-            (str) user id
-
-        Raises:
-            InvalidClientCredentialsError if there is no user_id caveat in the
-                macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix) :]
-        raise InvalidClientTokenError("No user caveat in macaroon")
-
     def validate_macaroon(self, macaroon, type_string, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.

@@ -465,21 +450,13 @@ class Auth:
         v.satisfy_exact("type = " + type_string)
         v.satisfy_exact("user_id = %s" % user_id)
         v.satisfy_exact("guest = true")
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self.clock.time_msec)

         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))

         v.verify(macaroon, self._macaroon_secret_key)

-    def _verify_expiry(self, caveat):
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.hs.get_clock().time_msec()
-        return now < expiry
-
     def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
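The auth hunks above replace inline caveat handling with `get_value_from_macaroon` and `satisfy_expiry` from `synapse.util.macaroons`. Judging only from the code being removed, the helpers need to do roughly the following; this is a reconstruction, not the real module, and the error type in particular is inferred from the new `except` clause:

```python
# Reconstruction based on the removed inline code; the real
# synapse.util.macaroons module may differ in detail.
import pymacaroons

def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
    """Return the value of the first `<key> = <value>` caveat on the macaroon."""
    prefix = key + " = "
    for caveat in macaroon.caveats:
        if caveat.caveat_id.startswith(prefix):
            return caveat.caveat_id[len(prefix):]
    # The new except clause above also catches KeyError, which suggests the
    # helper signals a missing caveat this way.
    raise KeyError(key)

def satisfy_expiry(verifier: pymacaroons.Verifier, get_time_ms) -> None:
    """Teach a verifier to accept `time < <expiry>` caveats that have not passed."""
    def verify_expiry_caveat(caveat: str) -> bool:
        prefix = "time < "
        if not caveat.startswith(prefix):
            return False
        expiry = int(caveat[len(prefix):])
        return get_time_ms() < expiry

    verifier.satisfy_general(verify_expiry_caveat)
```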
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
         self.clock = hs.get_clock()

         self.protocol_meta_cache = ResponseCache(
-            hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+            hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )  # type: ResponseCache[Tuple[str, str]]

     async def query_user(self, service, user_id):
@@ -212,9 +212,8 @@ class Config:

     @classmethod
     def read_file(cls, file_path, config_name):
-        cls.check_file(file_path, config_name)
-        with open(file_path) as file_stream:
-            return file_stream.read()
+        """Deprecated: call read_file directly"""
+        return read_file(file_path, (config_name,))

     def read_template(self, filename: str) -> jinja2.Template:
         """Load a template file from disk.

@@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
         return self._get_instance(key)


-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, emit an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+            error can be emitted if it does not exist.
+
+    Returns:
+        content of the file.
+
+    Raises:
+        ConfigError if there is a problem reading the file.
+    """
+    if not isinstance(file_path, str):
+        raise ConfigError("%r is not a string", config_path)
+
+    try:
+        os.stat(file_path)
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+    "Config",
+    "RootConfig",
+    "ShardedWorkerHandlingConfig",
+    "RoutableShardedWorkerHandlingConfig",
+    "read_file",
+]
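With `Config.read_file` reduced to a deprecated shim, callers are expected to use the module-level helper directly. A hedged usage sketch follows; the file path and config location tuple are placeholders, and a real call site appears in the OIDC hunk later in this diff:

```python
# Hedged usage sketch; path and config location tuple are placeholders.
from synapse.config._base import ConfigError, read_file

try:
    key_pem = read_file("/path/to/signing_key.pem", ("some_section", "key_file"))
except ConfigError:
    # Missing or unreadable files surface as a ConfigError that names the
    # config location, rather than a bare OSError.
    raise
```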
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:

 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
     def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
@@ -21,8 +21,10 @@ import threading
 from string import Template

 import yaml
+from zope.interface import implementer

 from twisted.logger import (
+    ILogObserver,
     LogBeginner,
     STDLibLogObserver,
     eventAsText,

@@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->

     threadlocal = threading.local()

-    def _log(event):
+    @implementer(ILogObserver)
+    def _log(event: dict) -> None:
         if "log_text" in event:
             if event["log_text"].startswith("DNSDatagramProtocol starting on "):
                 return
@@ -15,7 +15,7 @@
 # limitations under the License.

 from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type

 import attr

@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_mxc_uri

-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file

 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"

@@ -97,7 +97,26 @@ class OIDCConfig(Config):
 #
 #       client_id: Required. oauth2 client id to use.
 #
-#       client_secret: Required. oauth2 client secret to use.
+#       client_secret: oauth2 client secret to use. May be omitted if
+#           client_secret_jwt_key is given, or if client_auth_method is 'none'.
+#
+#       client_secret_jwt_key: Alternative to client_secret: details of a key used
+#           to create a JSON Web Token to be used as an OAuth2 client secret. If
+#           given, must be a dictionary with the following properties:
+#
+#           key: a pem-encoded signing key. Must be a suitable key for the
+#               algorithm specified. Required unless 'key_file' is given.
+#
+#           key_file: the path to file containing a pem-encoded signing key file.
+#               Required unless 'key' is given.
+#
+#           jwt_header: a dictionary giving properties to include in the JWT
+#               header. Must include the key 'alg', giving the algorithm used to
+#               sign the JWT, such as "ES256", using the JWA identifiers in
+#               RFC7518.
+#
+#           jwt_payload: an optional dictionary giving properties to include in
+#               the JWT payload. Normally this should include an 'iss' key.
 #
 #       client_auth_method: auth method to use when exchanging the token. Valid
 #           values are 'client_secret_basic' (default), 'client_secret_post' and

@@ -218,7 +237,7 @@ class OIDCConfig(Config):
 #
 #- idp_id: github
 #  idp_name: Github
-#  idp_brand: org.matrix.github
+#  idp_brand: github
 #  discover: false
 #  issuer: "https://github.com/"
 #  client_id: "your-client-id" # TO BE FILLED

@@ -240,7 +259,7 @@ class OIDCConfig(Config):
 # jsonschema definition of the configuration settings for an oidc identity provider
 OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
-    "required": ["issuer", "client_id", "client_secret"],
+    "required": ["issuer", "client_id"],
     "properties": {
         "idp_id": {
             "type": "string",

@@ -253,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "idp_icon": {"type": "string"},
         "idp_brand": {
             "type": "string",
-            # MSC2758-style namespaced identifier
+            "minLength": 1,
+            "maxLength": 255,
+            "pattern": "^[a-z][a-z0-9_.-]*$",
+        },
+        "idp_unstable_brand": {
+            "type": "string",
             "minLength": 1,
             "maxLength": 255,
             "pattern": "^[a-z][a-z0-9_.-]*$",

@@ -262,6 +286,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
         "client_secret": {"type": "string"},
+        "client_secret_jwt_key": {
+            "type": "object",
+            "required": ["jwt_header"],
+            "oneOf": [
+                {"required": ["key"]},
+                {"required": ["key_file"]},
+            ],
+            "properties": {
+                "key": {"type": "string"},
+                "key_file": {"type": "string"},
+                "jwt_header": {
+                    "type": "object",
+                    "required": ["alg"],
+                    "properties": {
+                        "alg": {"type": "string"},
+                    },
+                    "additionalProperties": {"type": "string"},
+                },
+                "jwt_payload": {
+                    "type": "object",
+                    "additionalProperties": {"type": "string"},
+                },
+            },
+        },
         "client_auth_method": {
             "type": "string",
             # the following list is the same as the keys of

@@ -404,15 +452,31 @@ def _parse_oidc_config_dict(
             "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
         ) from e

+    client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+    client_secret_jwt_key = None  # type: Optional[OidcProviderClientSecretJwtKey]
+    if client_secret_jwt_key_config is not None:
+        keyfile = client_secret_jwt_key_config.get("key_file")
+        if keyfile:
+            key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+        else:
+            key = client_secret_jwt_key_config["key"]
+        client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+            key=key,
+            jwt_header=client_secret_jwt_key_config["jwt_header"],
+            jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+        )
+
     return OidcProviderConfig(
         idp_id=idp_id,
         idp_name=oidc_config.get("idp_name", "OIDC"),
         idp_icon=idp_icon,
         idp_brand=oidc_config.get("idp_brand"),
+        unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
-        client_secret=oidc_config["client_secret"],
+        client_secret=oidc_config.get("client_secret"),
+        client_secret_jwt_key=client_secret_jwt_key,
         client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
         scopes=oidc_config.get("scopes", ["openid"]),
         authorization_endpoint=oidc_config.get("authorization_endpoint"),

@@ -427,6 +491,18 @@ def _parse_oidc_config_dict(
     )


+@attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+    # a pem-encoded signing key
+    key = attr.ib(type=str)
+
+    # properties to include in the JWT header
+    jwt_header = attr.ib(type=Mapping[str, str])
+
+    # properties to include in the JWT payload.
+    jwt_payload = attr.ib(type=Mapping[str, str])
+
+
 @attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
     # a unique identifier for this identity provider. Used in the 'user_external_ids'

@@ -442,6 +518,9 @@ class OidcProviderConfig:
     # Optional brand identifier for this IdP.
     idp_brand = attr.ib(type=Optional[str])

+    # Optional brand identifier for the unstable API (see MSC2858).
+    unstable_idp_brand = attr.ib(type=Optional[str])
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)

@@ -452,8 +531,13 @@ class OidcProviderConfig:
     # oauth2 client id to use
     client_id = attr.ib(type=str)

-    # oauth2 client secret to use
-    client_secret = attr.ib(type=str)
+    # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+    # a secret.
+    client_secret = attr.ib(type=Optional[str])
+
+    # key to use to construct a JWT to use as a client secret. May be `None` if
+    # `client_secret` is set.
+    client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])

     # auth method to use when exchanging the token.
     # Valid values are 'client_secret_basic', 'client_secret_post' and
@@ -841,8 +841,7 @@
         # Whether to require authentication to retrieve profile data (avatars,
         # display names) of other users through the client API. Defaults to
         # 'false'. Note that profile data is also available via the federation
-        # API, so this setting is of limited value if federation is enabled on
-        # the server.
+        # API, unless allow_profile_lookup_over_federation is set to false.
         #
         #require_auth_for_profile_requests: true
@@ -13,10 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import sys
+import logging

 from ._base import Config

+ROOM_STATS_DISABLED_WARN = """\
+WARNING: room/user statistics have been disabled via the stats.enabled
+configuration setting. This means that certain features (such as the room
+directory) will not operate correctly. Future versions of Synapse may ignore
+this setting.
+
+To fix this warning, remove the stats.enabled setting from your configuration
+file.
+--------------------------------------------------------------------------------"""
+
+logger = logging.getLogger(__name__)
+

 class StatsConfig(Config):
     """Stats Configuration

@@ -28,30 +40,29 @@ class StatsConfig(Config):
     def read_config(self, config, **kwargs):
         self.stats_enabled = True
         self.stats_bucket_size = 86400 * 1000
-        self.stats_retention = sys.maxsize
         stats_config = config.get("stats", None)
         if stats_config:
             self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
             self.stats_bucket_size = self.parse_duration(
                 stats_config.get("bucket_size", "1d")
             )
-            self.stats_retention = self.parse_duration(
-                stats_config.get("retention", "%ds" % (sys.maxsize,))
-            )
+        if not self.stats_enabled:
+            logger.warning(ROOM_STATS_DISABLED_WARN)

     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """
-        # Local statistics collection. Used in populating the room directory.
+        # Settings for local room and user statistics collection. See
+        # docs/room_and_user_statistics.md.
         #
-        # 'bucket_size' controls how large each statistics timeslice is. It can
-        # be defined in a human readable short form -- e.g. "1d", "1y".
-        #
-        # 'retention' controls how long historical statistics will be kept for.
-        # It can be defined in a human readable short form -- e.g. "1d", "1y".
-        #
-        #
-        #stats:
-        #   enabled: true
-        #   bucket_size: 1d
-        #   retention: 1y
+        stats:
+          # Uncomment the following to disable room and user statistics. Note that doing
+          # so may cause certain features (such as the room directory) not to work
+          # correctly.
+          #
+          #enabled: false
+
+          # The size of each timeslice in the room_stats_historical and
+          # user_stats_historical tables, as a time period. Defaults to "1d".
+          #
+          #bucket_size: 1h
         """

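For reference, a minimal homeserver.yaml fragment matching the new `stats` schema above; the old `retention` key is gone and only `enabled` and `bucket_size` remain (example values):

```yaml
stats:
  enabled: true
  bucket_size: 1d
```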
@@ -15,6 +15,7 @@
 # limitations under the License.

 import inspect
+import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from synapse.rest.media.v1._base import FileInfo

@@ -27,6 +28,8 @@ if TYPE_CHECKING:
     import synapse.events
     import synapse.server

+logger = logging.getLogger(__name__)
+

 class SpamChecker:
     def __init__(self, hs: "synapse.server.HomeServer"):

@@ -190,6 +193,7 @@ class SpamChecker:
         email_threepid: Optional[dict],
         username: Optional[str],
         request_info: Collection[Tuple[str, str]],
+        auth_provider_id: Optional[str] = None,
     ) -> RegistrationBehaviour:
         """Checks if we should allow the given registration request.

@@ -198,6 +202,9 @@ class SpamChecker:
             username: The request user name, if any
             request_info: List of tuples of user agent and IP that
                 were used during the registration process.
+            auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+                "cas". If any. Note this does not include users registered
+                via a password provider.

         Returns:
             Enum for how the request should be handled

@@ -208,9 +215,25 @@ class SpamChecker:
             # spam checker
             checker = getattr(spam_checker, "check_registration_for_spam", None)
             if checker:
-                behaviour = await maybe_awaitable(
-                    checker(email_threepid, username, request_info)
-                )
+                # Provide auth_provider_id if the function supports it
+                checker_args = inspect.signature(checker)
+                if len(checker_args.parameters) == 4:
+                    d = checker(
+                        email_threepid,
+                        username,
+                        request_info,
+                        auth_provider_id,
+                    )
+                elif len(checker_args.parameters) == 3:
+                    d = checker(email_threepid, username, request_info)
+                else:
+                    logger.error(
+                        "Invalid signature for %s.check_registration_for_spam. Denying registration",
+                        spam_checker.__module__,
+                    )
+                    return RegistrationBehaviour.DENY
+
+                behaviour = await maybe_awaitable(d)
                 assert isinstance(behaviour, RegistrationBehaviour)
                 if behaviour != RegistrationBehaviour.ALLOW:
                     return behaviour
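A hedged sketch of a third-party spam-checker module written against the new callback shape: because Synapse now inspects the callback's signature, a four-parameter `check_registration_for_spam` receives the SSO IdP id, while existing three-parameter modules keep working unchanged. The class name and the simple username rule are illustrative; only the method name, argument order and `RegistrationBehaviour` come from Synapse:

```python
# Hedged sketch of a spam-checker module that opts in to the new argument.
# Four parameters (self excluded) means Synapse passes the SSO IdP id; three
# parameters keeps the old behaviour.
from typing import Collection, Optional, Tuple

from synapse.spam_checker_api import RegistrationBehaviour


class ExampleSpamChecker:
    # ... usual spam-checker module plumbing (parse_config, __init__) omitted ...

    async def check_registration_for_spam(
        self,
        email_threepid: Optional[dict],
        username: Optional[str],
        request_info: Collection[Tuple[str, str]],
        auth_provider_id: Optional[str] = None,
    ) -> RegistrationBehaviour:
        # auth_provider_id is e.g. "oidc", "saml" or "cas" when SSO was used,
        # and None otherwise (including registrations via password providers).
        if auth_provider_id is None and username and "spam" in username:
            return RegistrationBehaviour.DENY
        return RegistrationBehaviour.ALLOW
```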
@@ -22,6 +22,7 @@ from typing import (
     Awaitable,
     Callable,
     Dict,
+    Iterable,
     List,
     Optional,
     Tuple,

@@ -90,16 +91,15 @@ pdu_process_time = Histogram(
     "Time taken to process an event",
 )

-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_received_pdu_age",
-    "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_received_pdu_time",
+    "The timestamp of the last PDU which was successfully received from the given domain",
     labelnames=("server_name",),
 )


 class FederationServer(FederationBase):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)

         self.auth = hs.get_auth()

@@ -112,14 +112,15 @@ class FederationServer(FederationBase):
         # with FederationHandlerRegistry.
         hs.get_directory_handler()

-        self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
         self._server_linearizer = Linearizer("fed_server")
-        self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+        # origins that we are currently processing a transaction from.
+        # a dict from origin to txn id.
+        self._active_transactions = {}  # type: Dict[str, str]

         # We cache results for transaction with the same ID
         self._transaction_resp_cache = ResponseCache(
-            hs, "fed_txn_handler", timeout_ms=30000
+            hs.get_clock(), "fed_txn_handler", timeout_ms=30000
         )  # type: ResponseCache[Tuple[str, str]]

         self.transaction_actions = TransactionActions(self.store)

@@ -129,10 +130,10 @@ class FederationServer(FederationBase):
         # We cache responses to state queries, as they take a while and often
         # come in waves.
         self._state_resp_cache = ResponseCache(
-            hs, "state_resp", timeout_ms=30000
+            hs.get_clock(), "state_resp", timeout_ms=30000
         )  # type: ResponseCache[Tuple[str, str]]
         self._state_ids_resp_cache = ResponseCache(
-            hs, "state_ids_resp", timeout_ms=30000
+            hs.get_clock(), "state_ids_resp", timeout_ms=30000
         )  # type: ResponseCache[Tuple[str, str]]

         self._federation_metrics_domains = (

@@ -169,6 +170,33 @@ class FederationServer(FederationBase):

         logger.debug("[%s] Got transaction", transaction_id)

+        # Reject malformed transactions early: reject if too many PDUs/EDUs
+        if len(transaction.pdus) > 50 or (  # type: ignore
+            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
+        ):
+            logger.info("Transaction PDU or EDU count too large. Returning 400")
+            return 400, {}
+
+        # we only process one transaction from each origin at a time. We need to do
+        # this check here, rather than in _on_incoming_transaction_inner so that we
+        # don't cache the rejection in _transaction_resp_cache (so that if the txn
+        # arrives again later, we can process it).
+        current_transaction = self._active_transactions.get(origin)
+        if current_transaction and current_transaction != transaction_id:
+            logger.warning(
+                "Received another txn %s from %s while still processing %s",
+                transaction_id,
+                origin,
+                current_transaction,
+            )
+            return 429, {
+                "errcode": Codes.UNKNOWN,
+                "error": "Too many concurrent transactions",
+            }
+
+        # CRITICAL SECTION: we must now not await until we populate _active_transactions
+        # in _on_incoming_transaction_inner.
+
         # We wrap in a ResponseCache so that we de-duplicate retried
         # transactions.
         return await self._transaction_resp_cache.wrap(

@@ -182,26 +210,18 @@ class FederationServer(FederationBase):
     async def _on_incoming_transaction_inner(
         self, origin: str, transaction: Transaction, request_time: int
     ) -> Tuple[int, Dict[str, Any]]:
-        # Use a linearizer to ensure that transactions from a remote are
-        # processed in order.
-        with await self._transaction_linearizer.queue(origin):
-            # We rate limit here *after* we've queued up the incoming requests,
-            # so that we don't fill up the ratelimiter with blocked requests.
-            #
-            # This is important as the ratelimiter allows N concurrent requests
-            # at a time, and only starts ratelimiting if there are more requests
-            # than that being processed at a time. If we queued up requests in
-            # the linearizer/response cache *after* the ratelimiting then those
-            # queued up requests would count as part of the allowed limit of N
-            # concurrent requests.
-            with self._federation_ratelimiter.ratelimit(origin) as d:
-                await d
-
-                result = await self._handle_incoming_transaction(
-                    origin, transaction, request_time
-                )
-
-        return result
+        # CRITICAL SECTION: the first thing we must do (before awaiting) is
+        # add an entry to _active_transactions.
+        assert origin not in self._active_transactions
+        self._active_transactions[origin] = transaction.transaction_id  # type: ignore
+
+        try:
+            result = await self._handle_incoming_transaction(
+                origin, transaction, request_time
+            )
+            return result
+        finally:
+            del self._active_transactions[origin]

     async def _handle_incoming_transaction(
         self, origin: str, transaction: Transaction, request_time: int

@@ -227,19 +247,6 @@ class FederationServer(FederationBase):

         logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore

-        # Reject if PDU count > 50 or EDU count > 100
-        if len(transaction.pdus) > 50 or (  # type: ignore
-            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
-        ):
-            logger.info("Transaction PDU or EDU count too large. Returning 400")
-
-            response = {}
-            await self.transaction_actions.set_response(
-                origin, transaction, 400, response
-            )
-            return 400, response
-
         # We process PDUs and EDUs in parallel. This is important as we don't
         # want to block things like to device messages from reaching clients
         # behind the potentially expensive handling of PDUs.

@@ -335,42 +342,48 @@ class FederationServer(FederationBase):
         # impose a limit to avoid going too crazy with ram/cpu.

         async def process_pdus_for_room(room_id: str):
-            logger.debug("Processing PDUs for %s", room_id)
-            try:
-                await self.check_server_matches_acl(origin_host, room_id)
-            except AuthError as e:
-                logger.warning("Ignoring PDUs for room %s from banned server", room_id)
-                for pdu in pdus_by_room[room_id]:
-                    event_id = pdu.event_id
-                    pdu_results[event_id] = e.error_dict()
-                return
-
-            for pdu in pdus_by_room[room_id]:
-                event_id = pdu.event_id
-                with pdu_process_time.time():
-                    with nested_logging_context(event_id):
-                        try:
-                            await self._handle_received_pdu(origin, pdu)
-                            pdu_results[event_id] = {}
-                        except FederationError as e:
-                            logger.warning("Error handling PDU %s: %s", event_id, e)
-                            pdu_results[event_id] = {"error": str(e)}
-                        except Exception as e:
-                            f = failure.Failure()
-                            pdu_results[event_id] = {"error": str(e)}
-                            logger.error(
-                                "Failed to handle PDU %s",
-                                event_id,
-                                exc_info=(f.type, f.value, f.getTracebackObject()),
-                            )
+            with nested_logging_context(room_id):
+                logger.debug("Processing PDUs for %s", room_id)
+
+                try:
+                    await self.check_server_matches_acl(origin_host, room_id)
+                except AuthError as e:
+                    logger.warning(
+                        "Ignoring PDUs for room %s from banned server", room_id
+                    )
+                    for pdu in pdus_by_room[room_id]:
+                        event_id = pdu.event_id
+                        pdu_results[event_id] = e.error_dict()
+                    return
+
+                for pdu in pdus_by_room[room_id]:
+                    pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+        async def process_pdu(pdu: EventBase) -> JsonDict:
+            event_id = pdu.event_id
+            with pdu_process_time.time():
+                with nested_logging_context(event_id):
+                    try:
+                        await self._handle_received_pdu(origin, pdu)
+                        return {}
+                    except FederationError as e:
+                        logger.warning("Error handling PDU %s: %s", event_id, e)
+                        return {"error": str(e)}
+                    except Exception as e:
+                        f = failure.Failure()
+                        logger.error(
+                            "Failed to handle PDU %s",
+                            event_id,
+                            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
+                        )
+                        return {"error": str(e)}

         await concurrently_execute(
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )

         if newest_pdu_ts and origin in self._federation_metrics_domains:
-            newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
-            last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+            last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)

         return pdu_results

@@ -448,18 +461,22 @@ class FederationServer(FederationBase):

     async def _on_state_ids_request_compute(self, room_id, event_id):
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
-        auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+        auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
         return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}

     async def _on_context_state_request_compute(
         self, room_id: str, event_id: str
     ) -> Dict[str, list]:
         if event_id:
-            pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+            pdus = await self.handler.get_state_for_pdu(
+                room_id, event_id
+            )  # type: Iterable[EventBase]
         else:
             pdus = (await self.state.get_current_state(room_id)).values()

-        auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+        auth_chain = await self.store.get_auth_chain(
+            room_id, [pdu.event_id for pdu in pdus]
+        )

         return {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],

@@ -863,7 +880,9 @@ class FederationHandlerRegistry:
         self.edu_handlers = (
             {}
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
-        self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
+        self.query_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]

         # Map from type to instance names that we should route EDU handling to.
         # We randomly choose one instance from the list to route to for each new

@@ -897,7 +916,7 @@ class FederationHandlerRegistry:
         self.edu_handlers[edu_type] = handler

     def register_query_handler(
-        self, query_type: str, handler: Callable[[dict], defer.Deferred]
+        self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
     ):
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.

@@ -970,7 +989,7 @@ class FederationHandlerRegistry:
         # Oh well, let's just log and move on.
         logger.warning("No handler registered for EDU type %s", edu_type)

-    async def on_query(self, query_type: str, args: dict):
+    async def on_query(self, query_type: str, args: dict) -> JsonDict:
         handler = self.query_handlers.get(query_type)
         if handler:
             return await handler(args)
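A self-contained sketch (illustrative names, not Synapse's own) of the per-origin guard the diff introduces: a plain dict maps each origin to the transaction id currently being processed, a second concurrent transaction id from the same origin gets a 429, and the check-then-insert pair contains no `await`, so it cannot race on a single event loop:

```python
# Hedged sketch of the "one in-flight transaction per origin" pattern above.
import asyncio
from typing import Dict, Tuple


class TransactionGuard:
    def __init__(self) -> None:
        self._active: Dict[str, str] = {}  # origin -> txn id being processed

    async def handle(self, origin: str, txn_id: str) -> Tuple[int, dict]:
        current = self._active.get(origin)
        if current and current != txn_id:
            return 429, {"errcode": "M_UNKNOWN", "error": "Too many concurrent transactions"}

        # Critical section: no await between the check above and this insert,
        # so two transactions from one origin cannot both get past it.
        self._active[origin] = txn_id
        try:
            await asyncio.sleep(0)  # stand-in for the real PDU/EDU processing
            return 200, {}
        finally:
            del self._active[origin]


async def main() -> None:
    guard = TransactionGuard()
    print(await guard.handle("example.org", "txn1"))


if __name__ == "__main__":
    asyncio.run(main())
```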
@@ -17,6 +17,7 @@ import datetime
 import logging
 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast

+import attr
 from prometheus_client import Counter

 from synapse.api.errors import (

@@ -93,6 +94,10 @@ class PerDestinationQueue:
         self._destination = destination
         self.transmission_loop_running = False

+        # Flag to signal to any running transmission loop that there is new data
+        # queued up to be sent.
+        self._new_data_to_send = False
+
         # True whilst we are sending events that the remote homeserver missed
         # because it was unreachable. We start in this state so we can perform
         # catch-up at startup.

@@ -108,7 +113,7 @@ class PerDestinationQueue:
         # destination (we are the only updater so this is safe)
         self._last_successful_stream_ordering = None  # type: Optional[int]

-        # a list of pending PDUs
+        # a queue of pending PDUs
         self._pending_pdus = []  # type: List[EventBase]

         # XXX this is never actually used: see

@@ -208,6 +213,10 @@ class PerDestinationQueue:
         transaction in the background.
         """

+        # Mark that we (may) have new things to send, so that any running
+        # transmission loop will recheck whether there is stuff to send.
+        self._new_data_to_send = True
+
         if self.transmission_loop_running:
             # XXX: this can get stuck on by a never-ending
             # request at which point pending_pdus just keeps growing.

@@ -250,125 +259,41 @@ class PerDestinationQueue:

             pending_pdus = []
             while True:
-                # We have to keep 2 free slots for presence and rr_edus
-                limit = MAX_EDUS_PER_TRANSACTION - 2
-
-                device_update_edus, dev_list_id = await self._get_device_update_edus(
-                    limit
-                )
-
-                limit -= len(device_update_edus)
-
-                (
-                    to_device_edus,
-                    device_stream_id,
-                ) = await self._get_to_device_message_edus(limit)
-
-                pending_edus = device_update_edus + to_device_edus
-
-                # BEGIN CRITICAL SECTION
-                #
-                # In order to avoid a race condition, we need to make sure that
-                # the following code (from popping the queues up to the point
-                # where we decide if we actually have any pending messages) is
-                # atomic - otherwise new PDUs or EDUs might arrive in the
-                # meantime, but not get sent because we hold the
-                # transmission_loop_running flag.
-
-                pending_pdus = self._pending_pdus
-
-                # We can only include at most 50 PDUs per transactions
-                pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
-
-                pending_edus.extend(self._get_rr_edus(force_flush=False))
-                pending_presence = self._pending_presence
-                self._pending_presence = {}
-                if pending_presence:
-                    pending_edus.append(
-                        Edu(
-                            origin=self._server_name,
-                            destination=self._destination,
-                            edu_type="m.presence",
-                            content={
-                                "push": [
-                                    format_user_presence_state(
-                                        presence, self._clock.time_msec()
-                                    )
-                                    for presence in pending_presence.values()
-                                ]
-                            },
-                        )
-                    )
-
-                pending_edus.extend(
-                    self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
-                )
-                while (
-                    len(pending_edus) < MAX_EDUS_PER_TRANSACTION
-                    and self._pending_edus_keyed
-                ):
-                    _, val = self._pending_edus_keyed.popitem()
-                    pending_edus.append(val)
-
-                if pending_pdus:
-                    logger.debug(
-                        "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        self._destination,
-                        len(pending_pdus),
-                    )
-
-                if not pending_pdus and not pending_edus:
-                    logger.debug("TX [%s] Nothing to send", self._destination)
-                    self._last_device_stream_id = device_stream_id
-                    return
-
-                # if we've decided to send a transaction anyway, and we have room, we
-                # may as well send any pending RRs
-                if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
-                    pending_edus.extend(self._get_rr_edus(force_flush=True))
-
-                # END CRITICAL SECTION
-
-                success = await self._transaction_manager.send_new_transaction(
-                    self._destination, pending_pdus, pending_edus
-                )
-                if success:
-                    sent_transactions_counter.inc()
-                    sent_edus_counter.inc(len(pending_edus))
-                    for edu in pending_edus:
-                        sent_edus_by_type.labels(edu.edu_type).inc()
-                    # Remove the acknowledged device messages from the database
-                    # Only bother if we actually sent some device messages
-                    if to_device_edus:
-                        await self._store.delete_device_msgs_for_remote(
-                            self._destination, device_stream_id
-                        )
-
-                    # also mark the device updates as sent
-                    if device_update_edus:
-                        logger.info(
-                            "Marking as sent %r %r", self._destination, dev_list_id
-                        )
-                        await self._store.mark_as_sent_devices_by_remote(
-                            self._destination, dev_list_id
-                        )
-
-                    self._last_device_stream_id = device_stream_id
-                    self._last_device_list_stream_id = dev_list_id
-
-                    if pending_pdus:
-                        # we sent some PDUs and it was successful, so update our
-                        # last_successful_stream_ordering in the destinations table.
-                        final_pdu = pending_pdus[-1]
-                        last_successful_stream_ordering = (
-                            final_pdu.internal_metadata.stream_ordering
-                        )
-                        assert last_successful_stream_ordering
-                        await self._store.set_destination_last_successful_stream_ordering(
-                            self._destination, last_successful_stream_ordering
-                        )
-                else:
-                    break
+                self._new_data_to_send = False
+
+                async with _TransactionQueueManager(self) as (
+                    pending_pdus,
+                    pending_edus,
+                ):
+                    if not pending_pdus and not pending_edus:
+                        logger.debug("TX [%s] Nothing to send", self._destination)
+
+                        # If we've gotten told about new things to send during
+                        # checking for things to send, we try looking again.
+                        # Otherwise new PDUs or EDUs might arrive in the meantime,
+                        # but not get sent because we hold the
+                        # `transmission_loop_running` flag.
+                        if self._new_data_to_send:
+                            continue
+                        else:
+                            return
+
+                    if pending_pdus:
+                        logger.debug(
+                            "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                            self._destination,
+                            len(pending_pdus),
+                        )
+
+                    await self._transaction_manager.send_new_transaction(
+                        self._destination, pending_pdus, pending_edus
+                    )
+
+                    sent_transactions_counter.inc()
+                    sent_edus_counter.inc(len(pending_edus))
+                    for edu in pending_edus:
+                        sent_edus_by_type.labels(edu.edu_type).inc()

         except NotRetryingDestination as e:
             logger.debug(
                 "TX [%s] not ready for retry yet (next retry at %s) - "

@@ -401,7 +326,7 @@ class PerDestinationQueue:
                 self._pending_presence = {}
                 self._pending_rrs = {}

-                self._start_catching_up()
+            self._start_catching_up()
         except FederationDeniedError as e:
             logger.info(e)
         except HttpResponseException as e:

@@ -412,7 +337,6 @@ class PerDestinationQueue:
                 e,
             )

-            self._start_catching_up()
         except RequestSendFailed as e:
             logger.warning(
                 "TX [%s] Failed to send transaction: %s", self._destination, e

@@ -422,16 +346,12 @@ class PerDestinationQueue:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
-
-            self._start_catching_up()
         except Exception:
             logger.exception("TX [%s] Failed to send transaction", self._destination)
             for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
-
-            self._start_catching_up()
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False

@@ -499,13 +419,10 @@ class PerDestinationQueue:
         rooms = [p.room_id for p in catchup_pdus]
         logger.info("Catching up rooms to %s: %r", self._destination, rooms)

-        success = await self._transaction_manager.send_new_transaction(
+        await self._transaction_manager.send_new_transaction(
             self._destination, catchup_pdus, []
         )

-        if not success:
-            return
-
         sent_transactions_counter.inc()
         final_pdu = catchup_pdus[-1]
         self._last_successful_stream_ordering = cast(

@@ -584,3 +501,135 @@ class PerDestinationQueue:
         """
         self._catching_up = True
         self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+    """A helper async context manager for pulling stuff off the queues and
+    tracking what was last successfully sent, etc.
+    """
+
+    queue = attr.ib(type=PerDestinationQueue)
+
+    _device_stream_id = attr.ib(type=Optional[int], default=None)
+    _device_list_id = attr.ib(type=Optional[int], default=None)
+    _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+    _pdus = attr.ib(type=List[EventBase], factory=list)
+
+    async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+        # First we calculate the EDUs we want to send, if any.
+
+        # We start by fetching device related EDUs, i.e device updates and to
+        # device messages. We have to keep 2 free slots for presence and rr_edus.
+        limit = MAX_EDUS_PER_TRANSACTION - 2
+
+        device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+            limit
+        )
+
+        if device_update_edus:
+            self._device_list_id = dev_list_id
+        else:
+            self.queue._last_device_list_stream_id = dev_list_id
+
+        limit -= len(device_update_edus)
+
+        (
+            to_device_edus,
+            device_stream_id,
+        ) = await self.queue._get_to_device_message_edus(limit)
+
+        if to_device_edus:
+            self._device_stream_id = device_stream_id
+        else:
+            self.queue._last_device_stream_id = device_stream_id
+
+        pending_edus = device_update_edus + to_device_edus
+
+        # Now add the read receipt EDU.
+        pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+        # And presence EDU.
+        if self.queue._pending_presence:
+            pending_edus.append(
+                Edu(
+                    origin=self.queue._server_name,
+                    destination=self.queue._destination,
+                    edu_type="m.presence",
+                    content={
+                        "push": [
+                            format_user_presence_state(
+                                presence, self.queue._clock.time_msec()
+                            )
+                            for presence in self.queue._pending_presence.values()
+                        ]
+                    },
+                )
+            )
+            self.queue._pending_presence = {}
+
+        # Finally add any other types of EDUs if there is room.
+        pending_edus.extend(
+            self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+        )
+        while (
+            len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+            and self.queue._pending_edus_keyed
+        ):
+            _, val = self.queue._pending_edus_keyed.popitem()
+            pending_edus.append(val)
+
+        # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+        # queue
+        self._pdus = self.queue._pending_pdus[:50]
+
+        if not self._pdus and not pending_edus:
+            return [], []
+
+        # if we've decided to send a transaction anyway, and we have room, we
+        # may as well send any pending RRs
+        if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+            pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+        if self._pdus:
+            self._last_stream_ordering = self._pdus[
+                -1
+            ].internal_metadata.stream_ordering
+            assert self._last_stream_ordering
+
+        return self._pdus, pending_edus
+
+    async def __aexit__(self, exc_type, exc, tb):
+        if exc_type is not None:
+            # Failed to send transaction, so we bail out.
+            return
+
+        # Successfully sent transactions, so we remove pending PDUs from the queue
+        if self._pdus:
+            self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+        # Succeeded to send the transaction so we record where we have sent up
+        # to in the various streams
+
+        if self._device_stream_id:
+            await self.queue._store.delete_device_msgs_for_remote(
+                self.queue._destination, self._device_stream_id
+            )
+            self.queue._last_device_stream_id = self._device_stream_id
+
+        # also mark the device updates as sent
+        if self._device_list_id:
+            logger.info(
+                "Marking as sent %r %r", self.queue._destination, self._device_list_id
+            )
+            await self.queue._store.mark_as_sent_devices_by_remote(
+                self.queue._destination, self._device_list_id
+            )
+            self.queue._last_device_list_stream_id = self._device_list_id
+
+        if self._last_stream_ordering:
+            # we sent some PDUs and it was successful, so update our
+            # last_successful_stream_ordering in the destinations table.
+            await self.queue._store.set_destination_last_successful_stream_ordering(
+                self.queue._destination, self._last_stream_ordering
+            )
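A simplified, self-contained sketch of the async-context-manager pattern `_TransactionQueueManager` implements: `__aenter__` pulls a batch off the pending queue, and `__aexit__` only removes it for good when the body completed without an exception, so a failed send leaves everything queued for the next attempt (names here are illustrative):

```python
# Hedged sketch of the queue-manager pattern: take work in __aenter__, commit
# the removal in __aexit__ only if no exception escaped the block.
import asyncio
from typing import List


class QueueManager:
    def __init__(self, pending: List[str]) -> None:
        self._pending = pending
        self._batch: List[str] = []

    async def __aenter__(self) -> List[str]:
        self._batch = self._pending[:50]  # take up to 50 items, mirroring the PDU limit
        return self._batch

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if exc_type is not None:
            return  # failed to send: leave the items queued for a retry
        del self._pending[: len(self._batch)]  # success: drop what we sent


async def main() -> None:
    pending = ["event1", "event2"]
    async with QueueManager(pending) as batch:
        print("sending", batch)
    print("still queued:", pending)


if __name__ == "__main__":
    asyncio.run(main())
```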
@@ -36,9 +36,9 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_sent_pdu_age",
-    "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_sent_pdu_time",
+    "The timestamp of the last PDU which was successfully sent to the given domain",
     labelnames=("server_name",),
 )

@@ -69,15 +69,12 @@ class TransactionManager:
         destination: str,
         pdus: List[EventBase],
         edus: List[Edu],
-    ) -> bool:
+    ) -> None:
         """
         Args:
             destination: The destination to send to (e.g. 'example.org')
             pdus: In-order list of PDUs to send
             edus: List of EDUs to send
-
-        Returns:
-            True iff the transaction was successful
         """

         # Make a transaction-sending opentracing span. This span follows on from

@@ -96,8 +93,6 @@ class TransactionManager:
             edu.strip_context()

         with start_active_span_follows_from("send_transaction", span_contexts):
-            success = True
-
             logger.debug("TX [%s] _attempt_new_transaction", destination)

             txn_id = str(self._next_txn_id)

@@ -152,45 +147,29 @@ class TransactionManager:
                 response = await self._transport_layer.send_transaction(
                     transaction, json_data_cb
                 )
-                code = 200
             except HttpResponseException as e:
                 code = e.code
                 response = e.response

-                if e.code in (401, 404, 429) or 500 <= e.code:
-                    logger.info(
-                        "TX [%s] {%s} got %d response", destination, txn_id, code
-                    )
-                    raise e
-
-            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
-            if code == 200:
-                for e_id, r in response.get("pdus", {}).items():
-                    if "error" in r:
-                        logger.warning(
-                            "TX [%s] {%s} Remote returned error for %s: %s",
-                            destination,
-                            txn_id,
-                            e_id,
-                            r,
-                        )
-            else:
-                for p in pdus:
-                    logger.warning(
-                        "TX [%s] {%s} Failed to send event %s",
-                        destination,
-                        txn_id,
-                        p.event_id,
-                    )
-                    success = False
-
-            if success and pdus and destination in self._federation_metrics_domains:
-                last_pdu = pdus[-1]
-                last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
-                last_pdu_age_metric.labels(server_name=destination).set(
-                    last_pdu_age / 1000
-                )
-
-            set_tag(tags.ERROR, not success)
-            return success
+                set_tag(tags.ERROR, True)
+
+                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+                raise
+
+            logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+            for e_id, r in response.get("pdus", {}).items():
+                if "error" in r:
+                    logger.warning(
+                        "TX [%s] {%s} Remote returned error for %s: %s",
+                        destination,
+                        txn_id,
+                        e_id,
+                        r,
+                    )
+
+            if pdus and destination in self._federation_metrics_domains:
+                last_pdu = pdus[-1]
+                last_pdu_ts_metric.labels(server_name=destination).set(
+                    last_pdu.origin_server_ts / 1000
+                )
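Both gauges now export the raw timestamp of the last PDU rather than a precomputed age, leaving staleness to be derived at query time (for example `time() - synapse_federation_last_sent_pdu_time` in PromQL). A small sketch of the pattern with prometheus_client; the helper function is illustrative:

```python
# Hedged sketch of the timestamp-gauge pattern: export origin_server_ts in
# seconds and let the monitoring system compute the age when it is queried.
from prometheus_client import Gauge

last_sent_pdu_time = Gauge(
    "synapse_federation_last_sent_pdu_time",
    "The timestamp of the last PDU which was successfully sent to the given domain",
    labelnames=("server_name",),
)


def record_sent_pdu(destination: str, origin_server_ts_ms: int) -> None:
    # origin_server_ts is in milliseconds; the gauge is set in seconds
    last_sent_pdu_time.labels(server_name=destination).set(origin_server_ts_ms / 1000)
```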
@@ -73,7 +73,9 @@ class AcmeHandler:
             "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
         )
         try:
-            self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+            self.reactor.listenTCP(
+                self.hs.config.acme_port, srv, backlog=50, interface=host
+            )
         except twisted.internet.error.CannotListenError as e:
             check_bind_error(e, host, bind_addresses)

@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email

@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)


+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000

@@ -326,7 +337,8 @@ class AuthHandler(BaseHandler):
             user is too high to proceed

         """
+        if not requester.access_token_id:
+            raise ValueError("Cannot validate a user without an access token")
         if self._ui_auth_session_timeout:
             last_validated = await self.store.get_access_token_last_validated(
                 requester.access_token_id

@@ -1164,18 +1176,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id

-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)

-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res

     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token

@@ -1204,7 +1214,7 @@ class AuthHandler(BaseHandler):
     async def delete_access_tokens_for_user(
         self,
         user_id: str,
-        except_token_id: Optional[str] = None,
+        except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
     ):
         """Invalidate access tokens belonging to a user

@@ -1397,6 +1407,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,

@@ -1406,6 +1417,9 @@ class AuthHandler(BaseHandler):

         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.

@@ -1427,6 +1441,7 @@ class AuthHandler(BaseHandler):

         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,

@@ -1437,6 +1452,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,

@@ -1463,7 +1479,7 @@ class AuthHandler(BaseHandler):

         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )

         # Append the login token to the original redirect URL (i.e. with its query

@@ -1569,15 +1585,48 @@ class MacaroonGenerator:
         return macaroon.serialize()

     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()

+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
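A self-contained sketch of the short-term login token scheme above, driving pymacaroons directly: the token carries `user_id`, `type = login`, an expiry and the new `auth_provider_id` caveat, and verification checks exactly those caveats. The secret key and helper names are illustrative:

```python
# Hedged sketch of the macaroon-based short-term login token scheme.
import time

import pymacaroons

SECRET_KEY = "not-a-real-macaroon-secret"


def mint(user_id: str, auth_provider_id: str, duration_ms: int = 2 * 60 * 1000) -> str:
    macaroon = pymacaroons.Macaroon(
        location="example.com", identifier="key", key=SECRET_KEY
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
    macaroon.add_first_party_caveat("type = login")
    expiry = int(time.time() * 1000) + duration_ms
    macaroon.add_first_party_caveat("time < %d" % (expiry,))
    macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
    return macaroon.serialize()


def verify(token: str) -> None:
    macaroon = pymacaroons.Macaroon.deserialize(token)
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = login")
    v.satisfy_general(lambda c: c.startswith("user_id = "))
    v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
    # accept any unexpired "time < ..." caveat
    v.satisfy_general(
        lambda c: c.startswith("time < ")
        and int(time.time() * 1000) < int(c[len("time < ") :])
    )
    # raises MacaroonVerificationFailedException if any caveat is unsatisfied
    v.verify(macaroon, SECRET_KEY)


if __name__ == "__main__":
    verify(mint("@alice:example.com", "oidc"))
    print("token verified")
```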
@@ -83,6 +83,7 @@ class CasHandler:
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
         self.idp_brand = None
+        self.unstable_idp_brand = None

         self._sso_handler = hs.get_sso_handler()

@@ -201,7 +201,7 @@ class FederationHandler(BaseHandler):
             or pdu.internal_metadata.is_outlier()
         )
         if already_seen:
-            logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
+            logger.debug("Already seen pdu")
             return

         # do some initial sanity-checking of the event. In particular, make

@@ -210,18 +210,14 @@ class FederationHandler(BaseHandler):
         try:
             self._sanity_check_event(pdu)
         except SynapseError as err:
-            logger.warning(
-                "[%s %s] Received event failed sanity checks", room_id, event_id
-            )
+            logger.warning("Received event failed sanity checks")
             raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)

         # If we are currently in the process of joining this room, then we
         # queue up events for later processing.
         if room_id in self.room_queues:
             logger.info(
-                "[%s %s] Queuing PDU from %s for now: join in progress",
-                room_id,
-                event_id,
+                "Queuing PDU from %s for now: join in progress",
                 origin,
             )
             self.room_queues[room_id].append((pdu, origin))

@@ -236,9 +232,7 @@ class FederationHandler(BaseHandler):
         is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
-                "[%s %s] Ignoring PDU from %s as we're not in the room",
-                room_id,
-                event_id,
+                "Ignoring PDU from %s as we're not in the room",
                 origin,
             )
             return None

@@ -250,7 +244,7 @@ class FederationHandler(BaseHandler):
         # We only backfill backwards to the min depth.
         min_depth = await self.get_min_depth_for_context(pdu.room_id)

-        logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
+        logger.debug("min_depth: %d", min_depth)

         prevs = set(pdu.prev_event_ids())
         seen = await self.store.have_events_in_timeline(prevs)

@@ -267,17 +261,13 @@ class FederationHandler(BaseHandler):
                 # If we're missing stuff, ensure we only fetch stuff one
                 # at a time.
                 logger.info(
-                    "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
-                    room_id,
-                    event_id,
+                    "Acquiring room lock to fetch %d missing prev_events: %s",
                     len(missing_prevs),
                     shortstr(missing_prevs),
                 )
                 with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                     logger.info(
-                        "[%s %s] Acquired room lock to fetch %d missing prev_events",
-                        room_id,
-                        event_id,
+                        "Acquired room lock to fetch %d missing prev_events",
                         len(missing_prevs),
                     )

@@ -297,9 +287,7 @@ class FederationHandler(BaseHandler):

                     if not prevs - seen:
                         logger.info(
-                            "[%s %s] Found all missing prev_events",
-                            room_id,
-                            event_id,
+                            "Found all missing prev_events",
                         )

         if prevs - seen:

@@ -329,9 +317,7 @@ class FederationHandler(BaseHandler):

             if sent_to_us_directly:
                 logger.warning(
-                    "[%s %s] Rejecting: failed to fetch %d prev events: %s",
-                    room_id,
-                    event_id,
+                    "Rejecting: failed to fetch %d prev events: %s",
                     len(prevs - seen),
                     shortstr(prevs - seen),
                 )

@@ -367,17 +353,16 @@ class FederationHandler(BaseHandler):
             # Ask the remote server for the states we don't
             # know about
             for p in prevs - seen:
-                logger.info(
-                    "Requesting state at missing prev_event %s",
-                    event_id,
-                )
+                logger.info("Requesting state after missing prev_event %s", p)

                 with nested_logging_context(p):
                     # note that if any of the missing prevs share missing state or
                     # auth events, the requests to fetch those events are deduped
                     # by the get_pdu_cache in federation_client.
-                    (remote_state, _,) = await self._get_state_for_room(
-                        origin, room_id, p, include_event_in_state=True
+                    remote_state = (
+                        await self._get_state_after_missing_prev_event(
+                            origin, room_id, p
+                        )
                     )

                     remote_state_map = {

@@ -414,10 +399,7 @@ class FederationHandler(BaseHandler):
                     state = [event_map[e] for e in state_map.values()]
                 except Exception:
                     logger.warning(
-                        "[%s %s] Error attempting to resolve state at missing "
-                        "prev_events",
-                        room_id,
-                        event_id,
+                        "Error attempting to resolve state at missing " "prev_events",
                         exc_info=True,
                     )
                     raise FederationError(

@@ -454,9 +436,7 @@ class FederationHandler(BaseHandler):
         latest |= seen

         logger.info(
-            "[%s %s]: Requesting missing events between %s and %s",
-            room_id,
-            event_id,
+            "Requesting missing events between %s and %s",
             shortstr(latest),
             event_id,
         )

@@ -523,15 +503,11 @@ class FederationHandler(BaseHandler):
             # We failed to get the missing events, but since we need to handle
             # the case of `get_missing_events` not returning the necessary
             # events anyway, it is safe to simply log the error and continue.
-            logger.warning(
-                "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
-            )
+            logger.warning("Failed to get prev_events: %s", e)
             return

         logger.info(
-            "[%s %s]: Got %d prev_events: %s",
-            room_id,
-            event_id,
+            "Got %d prev_events: %s",
             len(missing_events),
             shortstr(missing_events),
         )

@@ -542,9 +518,7 @@ class FederationHandler(BaseHandler):

         for ev in missing_events:
             logger.info(
-                "[%s %s] Handling received prev_event %s",
-                room_id,
-                event_id,
+                "Handling received prev_event %s",
                 ev.event_id,
             )
             with nested_logging_context(ev.event_id):

@@ -553,9 +527,7 @@ class FederationHandler(BaseHandler):
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
-                            "[%s %s] Received prev_event %s failed history check.",
+                            "Received prev_event %s failed history check.",
room_id,
|
|
||||||
event_id,
|
|
||||||
ev.event_id,
|
ev.event_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -566,7 +538,6 @@ class FederationHandler(BaseHandler):
|
||||||
destination: str,
|
destination: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
include_event_in_state: bool = False,
|
|
||||||
) -> Tuple[List[EventBase], List[EventBase]]:
|
) -> Tuple[List[EventBase], List[EventBase]]:
|
||||||
"""Requests all of the room state at a given event from a remote homeserver.
|
"""Requests all of the room state at a given event from a remote homeserver.
|
||||||
|
|
||||||
|
|
@ -574,11 +545,9 @@ class FederationHandler(BaseHandler):
|
||||||
destination: The remote homeserver to query for the state.
|
destination: The remote homeserver to query for the state.
|
||||||
room_id: The id of the room we're interested in.
|
room_id: The id of the room we're interested in.
|
||||||
event_id: The id of the event we want the state at.
|
event_id: The id of the event we want the state at.
|
||||||
include_event_in_state: if true, the event itself will be included in the
|
|
||||||
returned state event list.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of events in the state, possibly including the event itself, and
|
A list of events in the state, not including the event itself, and
|
||||||
a list of events in the auth chain for the given event.
|
a list of events in the auth chain for the given event.
|
||||||
"""
|
"""
|
||||||
(
|
(
|
||||||
|
|
@ -590,9 +559,6 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
desired_events = set(state_event_ids + auth_event_ids)
|
desired_events = set(state_event_ids + auth_event_ids)
|
||||||
|
|
||||||
if include_event_in_state:
|
|
||||||
desired_events.add(event_id)
|
|
||||||
|
|
||||||
event_map = await self._get_events_from_store_or_dest(
|
event_map = await self._get_events_from_store_or_dest(
|
||||||
destination, room_id, desired_events
|
destination, room_id, desired_events
|
||||||
)
|
)
|
||||||
|
|
@ -609,13 +575,6 @@ class FederationHandler(BaseHandler):
|
||||||
event_map[e_id] for e_id in state_event_ids if e_id in event_map
|
event_map[e_id] for e_id in state_event_ids if e_id in event_map
|
||||||
]
|
]
|
||||||
|
|
||||||
if include_event_in_state:
|
|
||||||
remote_event = event_map.get(event_id)
|
|
||||||
if not remote_event:
|
|
||||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
|
||||||
if remote_event.is_state() and remote_event.rejected_reason is None:
|
|
||||||
remote_state.append(remote_event)
|
|
||||||
|
|
||||||
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
|
|
@ -689,6 +648,131 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
return fetched_events
|
return fetched_events
|
||||||
|
|
||||||
|
async def _get_state_after_missing_prev_event(
|
||||||
|
self,
|
||||||
|
destination: str,
|
||||||
|
room_id: str,
|
||||||
|
event_id: str,
|
||||||
|
) -> List[EventBase]:
|
||||||
|
"""Requests all of the room state at a given event from a remote homeserver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination: The remote homeserver to query for the state.
|
||||||
|
room_id: The id of the room we're interested in.
|
||||||
|
event_id: The id of the event we want the state at.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of events in the state, including the event itself
|
||||||
|
"""
|
||||||
|
# TODO: This function is basically the same as _get_state_for_room. Can
|
||||||
|
# we make backfill() use it, rather than having two code paths? I think the
|
||||||
|
# only difference is that backfill() persists the prev events separately.
|
||||||
|
|
||||||
|
(
|
||||||
|
state_event_ids,
|
||||||
|
auth_event_ids,
|
||||||
|
) = await self.federation_client.get_room_state_ids(
|
||||||
|
destination, room_id, event_id=event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"state_ids returned %i state events, %i auth events",
|
||||||
|
len(state_event_ids),
|
||||||
|
len(auth_event_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
# start by just trying to fetch the events from the store
|
||||||
|
desired_events = set(state_event_ids)
|
||||||
|
desired_events.add(event_id)
|
||||||
|
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
||||||
|
fetched_events = await self.store.get_events(
|
||||||
|
desired_events, allow_rejected=True
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_desired_events = desired_events - fetched_events.keys()
|
||||||
|
logger.debug(
|
||||||
|
"We are missing %i events (got %i)",
|
||||||
|
len(missing_desired_events),
|
||||||
|
len(fetched_events),
|
||||||
|
)
|
||||||
|
|
||||||
|
# We probably won't need most of the auth events, so let's just check which
|
||||||
|
# we have for now, rather than thrashing the event cache with them all
|
||||||
|
# unnecessarily.
|
||||||
|
|
||||||
|
# TODO: we probably won't actually need all of the auth events, since we
|
||||||
|
# already have a bunch of the state events. It would be nice if the
|
||||||
|
# federation api gave us a way of finding out which we actually need.
|
||||||
|
|
||||||
|
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
||||||
|
missing_auth_events.difference_update(
|
||||||
|
await self.store.have_seen_events(missing_auth_events)
|
||||||
|
)
|
||||||
|
logger.debug("We are also missing %i auth events", len(missing_auth_events))
|
||||||
|
|
||||||
|
missing_events = missing_desired_events | missing_auth_events
|
||||||
|
logger.debug("Fetching %i events from remote", len(missing_events))
|
||||||
|
await self._get_events_and_persist(
|
||||||
|
destination=destination, room_id=room_id, events=missing_events
|
||||||
|
)
|
||||||
|
|
||||||
|
# we need to make sure we re-load from the database to get the rejected
|
||||||
|
# state correct.
|
||||||
|
fetched_events.update(
|
||||||
|
(await self.store.get_events(missing_desired_events, allow_rejected=True))
|
||||||
|
)
|
||||||
|
|
||||||
|
# check for events which were in the wrong room.
|
||||||
|
#
|
||||||
|
# this can happen if a remote server claims that the state or
|
||||||
|
# auth_events at an event in room A are actually events in room B
|
||||||
|
|
||||||
|
bad_events = [
|
||||||
|
(event_id, event.room_id)
|
||||||
|
for event_id, event in fetched_events.items()
|
||||||
|
if event.room_id != room_id
|
||||||
|
]
|
||||||
|
|
||||||
|
for bad_event_id, bad_room_id in bad_events:
|
||||||
|
# This is a bogus situation, but since we may only discover it a long time
|
||||||
|
# after it happened, we try our best to carry on, by just omitting the
|
||||||
|
# bad events from the returned state set.
|
||||||
|
logger.warning(
|
||||||
|
"Remote server %s claims event %s in room %s is an auth/state "
|
||||||
|
"event in room %s",
|
||||||
|
destination,
|
||||||
|
bad_event_id,
|
||||||
|
bad_room_id,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
del fetched_events[bad_event_id]
|
||||||
|
|
||||||
|
# if we couldn't get the prev event in question, that's a problem.
|
||||||
|
remote_event = fetched_events.get(event_id)
|
||||||
|
if not remote_event:
|
||||||
|
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||||
|
|
||||||
|
# missing state at that event is a warning, not a blocker
|
||||||
|
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||||
|
# state.
|
||||||
|
failed_to_fetch = desired_events - fetched_events.keys()
|
||||||
|
if failed_to_fetch:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to fetch missing state events for %s %s",
|
||||||
|
event_id,
|
||||||
|
failed_to_fetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_state = [
|
||||||
|
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
|
||||||
|
]
|
||||||
|
|
||||||
|
if remote_event.is_state() and remote_event.rejected_reason is None:
|
||||||
|
remote_state.append(remote_event)
|
||||||
|
|
||||||
|
return remote_state
|
||||||
|
|
||||||
async def _process_received_pdu(
|
async def _process_received_pdu(
|
||||||
self,
|
self,
|
||||||
origin: str,
|
origin: str,
|
||||||
|
|
@@ -707,10 +791,7 @@ class FederationHandler(BaseHandler):
                 (ie, we are missing one or more prev_events), the resolved state at the
                 event
         """
-        room_id = event.room_id
-        event_id = event.event_id
-
-        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+        logger.debug("Processing event: %s", event)
 
         try:
             await self._handle_new_event(origin, event, state=state)
@@ -871,7 +952,6 @@ class FederationHandler(BaseHandler):
                     destination=dest,
                     room_id=room_id,
                     event_id=e_id,
-                    include_event_in_state=False,
                 )
                 auth_events.update({a.event_id: a for a in auth})
                 auth_events.update({s.event_id: s for s in state})
@@ -1317,7 +1397,7 @@ class FederationHandler(BaseHandler):
     async def on_event_auth(self, event_id: str) -> List[EventBase]:
         event = await self.store.get_event(event_id)
         auth = await self.store.get_auth_chain(
-            list(event.auth_event_ids()), include_given=True
+            event.room_id, list(event.auth_event_ids()), include_given=True
         )
         return list(auth)
 
@@ -1580,7 +1660,7 @@ class FederationHandler(BaseHandler):
         prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
-        auth_chain = await self.store.get_auth_chain(state_ids)
+        auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
 
         state = await self.store.get_events(list(prev_state_ids.values()))
 
@@ -2219,7 +2299,7 @@ class FederationHandler(BaseHandler):
 
         # Now get the current auth_chain for the event.
         local_auth_chain = await self.store.get_auth_chain(
-            list(event.auth_event_ids()), include_given=True
+            room_id, list(event.auth_event_ids()), include_given=True
         )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
         self.clock = hs.get_clock()
         self.validator = EventValidator()
         self.snapshot_cache = ResponseCache(
-            hs, "initial_sync_cache"
+            hs.get_clock(), "initial_sync_cache"
         )  # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2020 Quentin Gliech
|
# Copyright 2020 Quentin Gliech
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|
@ -14,13 +15,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
from authlib.common.security import generate_token
|
from authlib.common.security import generate_token
|
||||||
from authlib.jose import JsonWebToken
|
from authlib.jose import JsonWebToken, jwt
|
||||||
from authlib.oauth2.auth import ClientAuth
|
from authlib.oauth2.auth import ClientAuth
|
||||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||||
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
||||||
|
|
@ -28,20 +29,26 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||||
from jinja2 import Environment, Template
|
from jinja2 import Environment, Template
|
||||||
from pymacaroons.exceptions import (
|
from pymacaroons.exceptions import (
|
||||||
MacaroonDeserializationException,
|
MacaroonDeserializationException,
|
||||||
|
MacaroonInitException,
|
||||||
MacaroonInvalidSignatureException,
|
MacaroonInvalidSignatureException,
|
||||||
)
|
)
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.oidc_config import OidcProviderConfig
|
from synapse.config.oidc_config import (
|
||||||
|
OidcProviderClientSecretJwtKey,
|
||||||
|
OidcProviderConfig,
|
||||||
|
)
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import Clock, json_decoder
|
||||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||||
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
@ -211,7 +218,7 @@ class OidcHandler:
|
||||||
session_data = self._token_generator.verify_oidc_session_token(
|
session_data = self._token_generator.verify_oidc_session_token(
|
||||||
session, state
|
session, state
|
||||||
)
|
)
|
||||||
except (MacaroonDeserializationException, ValueError) as e:
|
except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
|
||||||
logger.exception("Invalid session for OIDC callback")
|
logger.exception("Invalid session for OIDC callback")
|
||||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||||
return
|
return
|
||||||
|
|
@@ -275,9 +282,21 @@ class OidcProvider:
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
 
+        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        if provider.client_secret:
+            client_secret = provider.client_secret
+        elif provider.client_secret_jwt_key:
+            client_secret = JwtClientSecret(
+                provider.client_secret_jwt_key,
+                provider.client_id,
+                provider.issuer,
+                hs.get_clock(),
+            )
+
         self._client_auth = ClientAuth(
             provider.client_id,
-            provider.client_secret,
+            client_secret,
             provider.client_auth_method,
         )  # type: ClientAuth
         self._client_auth_method = provider.client_auth_method
|
|
@ -312,6 +331,9 @@ class OidcProvider:
|
||||||
# optional brand identifier for this auth provider
|
# optional brand identifier for this auth provider
|
||||||
self.idp_brand = provider.idp_brand
|
self.idp_brand = provider.idp_brand
|
||||||
|
|
||||||
|
# Optional brand identifier for the unstable API (see MSC2858).
|
||||||
|
self.unstable_idp_brand = provider.unstable_idp_brand
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
|
@ -521,7 +543,7 @@ class OidcProvider:
|
||||||
"""
|
"""
|
||||||
metadata = await self.load_metadata()
|
metadata = await self.load_metadata()
|
||||||
token_endpoint = metadata.get("token_endpoint")
|
token_endpoint = metadata.get("token_endpoint")
|
||||||
headers = {
|
raw_headers = {
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
"User-Agent": self._http_client.user_agent,
|
"User-Agent": self._http_client.user_agent,
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
|
@ -535,10 +557,10 @@ class OidcProvider:
|
||||||
body = urlencode(args, True)
|
body = urlencode(args, True)
|
||||||
|
|
||||||
# Fill the body/headers with credentials
|
# Fill the body/headers with credentials
|
||||||
uri, headers, body = self._client_auth.prepare(
|
uri, raw_headers, body = self._client_auth.prepare(
|
||||||
method="POST", uri=token_endpoint, headers=headers, body=body
|
method="POST", uri=token_endpoint, headers=raw_headers, body=body
|
||||||
)
|
)
|
||||||
headers = {k: [v] for (k, v) in headers.items()}
|
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||||
|
|
||||||
# Do the actual request
|
# Do the actual request
|
||||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||||
|
|
@ -745,7 +767,7 @@ class OidcProvider:
|
||||||
idp_id=self.idp_id,
|
idp_id=self.idp_id,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
client_redirect_url=client_redirect_url.decode(),
|
client_redirect_url=client_redirect_url.decode(),
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id or "",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -976,6 +998,81 @@ class OidcProvider:
         return str(remote_user_id)
 
 
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+    """A class which generates a new client secret on demand, based on a JWK
+
+    This implementation is designed to comply with the requirements for Apple Sign in:
+    https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+    It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+    but it's worth noting that we still put the generated secret in the "client_secret"
+    field (or rather, whereever client_auth_method puts it) rather than in a
+    client_assertion field in the body as that RFC seems to require.
+    """
+
+    def __init__(
+        self,
+        key: OidcProviderClientSecretJwtKey,
+        oauth_client_id: str,
+        oauth_issuer: str,
+        clock: Clock,
+    ):
+        self._key = key
+        self._oauth_client_id = oauth_client_id
+        self._oauth_issuer = oauth_issuer
+        self._clock = clock
+        self._cached_secret = b""
+        self._cached_secret_replacement_time = 0
+
+    def __str__(self):
+        # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+        # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+        # here.
+        return self._get_secret().decode("ascii")
+
+    def __bytes__(self):
+        # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+        # encode_client_secret_post, which ends up here.
+        return self._get_secret()
+
+    def _get_secret(self) -> bytes:
+        now = self._clock.time()
+
+        # if we have enough validity on our existing secret, use it
+        if now < self._cached_secret_replacement_time:
+            return self._cached_secret
+
+        issued_at = int(now)
+        expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+        # we copy the configured header because jwt.encode modifies it.
+        header = dict(self._key.jwt_header)
+
+        # see https://tools.ietf.org/html/rfc7523#section-3
+        payload = {
+            "sub": self._oauth_client_id,
+            "aud": self._oauth_issuer,
+            "iat": issued_at,
+            "exp": expires_at,
+            **self._key.jwt_payload,
+        }
+        logger.info(
+            "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+        )
+        self._cached_secret = jwt.encode(header, payload, self._key.key)
+        self._cached_secret_replacement_time = (
+            expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+        )
+        return self._cached_secret
+
+
 class OidcSessionTokenGenerator:
     """Methods for generating and checking OIDC Session cookies."""
 
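For readers who want to try the idea outside Synapse, here is a rough standalone sketch of the same caching pattern, using authlib's `jwt.encode` directly. The class name, the PEM signing key, issuer and client id are made-up placeholders (this is not the Synapse API); only the validity constants and the RFC 7523 claim set mirror the code above.

import time

from authlib.jose import jwt

CLIENT_SECRET_VALIDITY_SECONDS = 3600
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600


class CachingJwtSecret:
    """Hypothetical helper: re-issues a short-lived client-secret JWT only
    when the cached one is close to expiry."""

    def __init__(self, signing_key_pem: bytes, client_id: str, issuer: str):
        self._key = signing_key_pem
        self._client_id = client_id
        self._issuer = issuer
        self._cached = b""
        self._replace_at = 0.0

    def get(self) -> bytes:
        now = time.time()
        if now < self._replace_at:
            # still enough validity left on the cached token
            return self._cached

        issued_at = int(now)
        expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
        header = {"alg": "ES256"}  # Apple Sign in expects an ES256-signed JWT
        payload = {
            "sub": self._client_id,
            "aud": self._issuer,
            "iat": issued_at,
            "exp": expires_at,
        }
        self._cached = jwt.encode(header, payload, self._key)
        self._replace_at = expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
        return self._cached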
|
|
@ -1020,10 +1117,9 @@ class OidcSessionTokenGenerator:
|
||||||
macaroon.add_first_party_caveat(
|
macaroon.add_first_party_caveat(
|
||||||
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
||||||
)
|
)
|
||||||
if session_data.ui_auth_session_id:
|
macaroon.add_first_party_caveat(
|
||||||
macaroon.add_first_party_caveat(
|
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
||||||
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
)
|
||||||
)
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
|
@ -1046,7 +1142,7 @@ class OidcSessionTokenGenerator:
|
||||||
The data extracted from the session cookie
|
The data extracted from the session cookie
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError if an expected caveat is missing from the macaroon.
|
KeyError if an expected caveat is missing from the macaroon.
|
||||||
"""
|
"""
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||||
|
|
||||||
|
|
@ -1057,26 +1153,16 @@ class OidcSessionTokenGenerator:
|
||||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
|
||||||
# to always satisfy this.
|
|
||||||
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
||||||
v.satisfy_general(self._verify_expiry)
|
satisfy_expiry(v, self._clock.time_msec)
|
||||||
|
|
||||||
v.verify(macaroon, self._macaroon_secret_key)
|
v.verify(macaroon, self._macaroon_secret_key)
|
||||||
|
|
||||||
# Extract the session data from the token.
|
# Extract the session data from the token.
|
||||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
nonce = get_value_from_macaroon(macaroon, "nonce")
|
||||||
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
|
idp_id = get_value_from_macaroon(macaroon, "idp_id")
|
||||||
client_redirect_url = self._get_value_from_macaroon(
|
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
|
||||||
macaroon, "client_redirect_url"
|
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
|
||||||
)
|
|
||||||
try:
|
|
||||||
ui_auth_session_id = self._get_value_from_macaroon(
|
|
||||||
macaroon, "ui_auth_session_id"
|
|
||||||
) # type: Optional[str]
|
|
||||||
except ValueError:
|
|
||||||
ui_auth_session_id = None
|
|
||||||
|
|
||||||
return OidcSessionData(
|
return OidcSessionData(
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
idp_id=idp_id,
|
idp_id=idp_id,
|
||||||
|
|
@ -1084,33 +1170,6 @@ class OidcSessionTokenGenerator:
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
|
|
||||||
"""Extracts a caveat value from a macaroon token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
macaroon: the token
|
|
||||||
key: the key of the caveat to extract
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The extracted value
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if the caveat was not in the macaroon
|
|
||||||
"""
|
|
||||||
prefix = key + " = "
|
|
||||||
for caveat in macaroon.caveats:
|
|
||||||
if caveat.caveat_id.startswith(prefix):
|
|
||||||
return caveat.caveat_id[len(prefix) :]
|
|
||||||
raise ValueError("No %s caveat in macaroon" % (key,))
|
|
||||||
|
|
||||||
def _verify_expiry(self, caveat: str) -> bool:
|
|
||||||
prefix = "time < "
|
|
||||||
if not caveat.startswith(prefix):
|
|
||||||
return False
|
|
||||||
expiry = int(caveat[len(prefix) :])
|
|
||||||
now = self._clock.time_msec()
|
|
||||||
return now < expiry
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True)
|
@attr.s(frozen=True, slots=True)
|
||||||
class OidcSessionData:
|
class OidcSessionData:
|
||||||
|
|
@ -1125,8 +1184,8 @@ class OidcSessionData:
|
||||||
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
||||||
client_redirect_url = attr.ib(type=str)
|
client_redirect_url = attr.ib(type=str)
|
||||||
|
|
||||||
# The session ID of the ongoing UI Auth (None if this is a login)
|
# The session ID of the ongoing UI Auth ("" if this is a login)
|
||||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
ui_auth_session_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
UserAttributeDict = TypedDict(
|
UserAttributeDict = TypedDict(
|
||||||
|
|
|
||||||
|
|
@@ -285,7 +285,7 @@ class PaginationHandler:
         except Exception:
             f = Failure()
             logger.error(
-                "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+                "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())  # type: ignore
             )
             self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
         finally:
@@ -16,7 +16,9 @@
 """Contains functions for registering clients."""
 
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
+from prometheus_client import Counter
+
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+registration_counter = Counter(
+    "synapse_user_registrations_total",
+    "Number of new users registered (since restart)",
+    ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+    "synapse_user_logins_total",
+    "Number of user logins (since restart)",
+    ["guest", "auth_provider"],
+)
+
+
 class RegistrationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
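As an aside, the two counters above follow the usual prometheus_client labelling pattern: each distinct combination of label values becomes its own time series, and call sites bump the series they care about with `labels(...).inc()`. A self-contained sketch (the wrapper function and its arguments are invented for illustration; only the metric and label names come from the hunk above):

from prometheus_client import Counter

registration_counter = Counter(
    "synapse_user_registrations_total",
    "Number of new users registered (since restart)",
    ["guest", "shadow_banned", "auth_provider"],
)


def record_registration(make_guest: bool, shadow_banned: bool, auth_provider_id: str = "") -> None:
    # labels() selects the series for this label combination; inc() adds one to it.
    registration_counter.labels(
        guest=make_guest,
        shadow_banned=shadow_banned,
        auth_provider=auth_provider_id or "",
    ).inc()


record_registration(make_guest=False, shadow_banned=False, auth_provider_id="oidc-example")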
@@ -67,6 +82,7 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             self.device_handler = hs.get_device_handler()
+            self._register_device_client = self.register_device_inner
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.session_lifetime
@@ -161,6 +177,7 @@ class RegistrationHandler(BaseHandler):
         bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+        auth_provider_id: Optional[str] = None,
     ) -> str:
         """Registers a new client on the server.
 
@@ -186,8 +203,9 @@ class RegistrationHandler(BaseHandler):
                 admin api, otherwise False.
             user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
+            auth_provider_id: The SSO IdP the user used, if any.
         Returns:
-            The registere user_id.
+            The registered user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -197,6 +215,7 @@ class RegistrationHandler(BaseHandler):
             threepid,
             localpart,
             user_agent_ips or [],
+            auth_provider_id=auth_provider_id,
         )
 
         if result == RegistrationBehaviour.DENY:
@@ -287,6 +306,12 @@ class RegistrationHandler(BaseHandler):
                 # if user id is taken, just generate another
                 fail_count += 1
 
+        registration_counter.labels(
+            guest=make_guest,
+            shadow_banned=shadow_banned,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
         if not self.hs.config.user_consent_at_registration:
             if not self.hs.config.auto_join_rooms_for_guests and make_guest:
                 logger.info(
@@ -645,6 +670,7 @@ class RegistrationHandler(BaseHandler):
         initial_display_name: Optional[str],
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
+        auth_provider_id: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
@@ -655,21 +681,40 @@ class RegistrationHandler(BaseHandler):
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
         Returns:
             Tuple of device ID and access token
         """
-        if self.hs.config.worker_app:
-            r = await self._register_device_client(
-                user_id=user_id,
-                device_id=device_id,
-                initial_display_name=initial_display_name,
-                is_guest=is_guest,
-                is_appservice_ghost=is_appservice_ghost,
-            )
-            return r["device_id"], r["access_token"]
+        res = await self._register_device_client(
+            user_id=user_id,
+            device_id=device_id,
+            initial_display_name=initial_display_name,
+            is_guest=is_guest,
+            is_appservice_ghost=is_appservice_ghost,
+        )
+
+        login_counter.labels(
+            guest=is_guest,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
+        return res["device_id"], res["access_token"]
+
+    async def register_device_inner(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        is_guest: bool = False,
+        is_appservice_ghost: bool = False,
+    ) -> Dict[str, str]:
+        """Helper for register_device
+
+        Does the bits that need doing on the main process. Not for use outside this
+        class and RegisterDeviceReplicationServlet.
+        """
+        assert not self.hs.config.worker_app
         valid_until_ms = None
         if self.session_lifetime is not None:
             if is_guest:
@@ -694,7 +739,7 @@ class RegistrationHandler(BaseHandler):
             is_appservice_ghost=is_appservice_ghost,
         )
 
-        return (registered_device_id, access_token)
+        return {"device_id": registered_device_id, "access_token": access_token}
 
     async def post_registration_actions(
         self, user_id: str, auth_result: dict, access_token: Optional[str]
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
         # succession, only process the first attempt and return its result to
         # subsequent requests
         self._upgrade_response_cache = ResponseCache(
-            hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+            hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
         )  # type: ResponseCache[Tuple[str, str]]
         self._server_notices_mxid = hs.config.server_notices_mxid
 
@@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
         self.response_cache = ResponseCache(
-            hs, "room_list"
+            hs.get_clock(), "room_list"
         )  # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
         self.remote_response_cache = ResponseCache(
-            hs, "remote_room_list", timeout_ms=30 * 1000
+            hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
         )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
 
     async def get_local_public_room_list(
@@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
         self.idp_brand = None
+        self.unstable_idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
@@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
         """Optional branding identifier"""
         return None
 
+    @property
+    def unstable_idp_brand(self) -> Optional[str]:
+        """Optional brand identifier for the unstable API (see MSC2858)."""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
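To illustrate where the new property is consumed, here is a minimal, hypothetical identity-provider stub that opts in to both brand fields; everything except the two property names is invented and it is not a working Synapse provider:

from typing import Optional


class ExampleIdP:
    """Toy stand-in for an SsoIdentityProvider implementation."""

    idp_id = "example"
    idp_name = "Example IdP"

    @property
    def idp_brand(self) -> Optional[str]:
        return "example-brand"

    @property
    def unstable_idp_brand(self) -> Optional[str]:
        # Only advertised on the unstable MSC2858 API surface.
        return "org.example.brand"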
@@ -456,6 +461,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_login_attributes,
@@ -605,6 +611,7 @@ class SsoHandler:
             default_display_name=attributes.display_name,
             bind_emails=attributes.emails,
             user_agent_ips=[(user_agent, ip_address)],
+            auth_provider_id=auth_provider_id,
         )
 
         await self._store.record_user_external_id(
@@ -886,6 +893,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            session.auth_provider_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
@@ -244,7 +244,7 @@ class SyncHandler:
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(
-            hs, "sync"
+            hs.get_clock(), "sync"
         )  # type: ResponseCache[Tuple[Any, ...]]
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
|
|
|
||||||
|
|
@ -39,12 +39,15 @@ from zope.interface import implementer, provider
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import VERIFY_NONE
|
from OpenSSL.SSL import VERIFY_NONE
|
||||||
from twisted.internet import defer, error as twisted_error, protocol, ssl
|
from twisted.internet import defer, error as twisted_error, protocol, ssl
|
||||||
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
IAddress,
|
IAddress,
|
||||||
IHostResolution,
|
IHostResolution,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IResolutionReceiver,
|
IResolutionReceiver,
|
||||||
|
ITCPTransport,
|
||||||
)
|
)
|
||||||
|
from twisted.internet.protocol import connectionDone
|
||||||
from twisted.internet.task import Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
|
@ -56,13 +59,20 @@ from twisted.web.client import (
|
||||||
)
|
)
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
from twisted.web.iweb import (
|
||||||
|
UNKNOWN_LENGTH,
|
||||||
|
IAgent,
|
||||||
|
IBodyProducer,
|
||||||
|
IPolicyForHTTPS,
|
||||||
|
IResponse,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
||||||
from synapse.http.proxyagent import ProxyAgent
|
from synapse.http.proxyagent import ProxyAgent
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
|
|
||||||
|
|
@ -150,16 +160,17 @@ class _IPBlacklistingResolver:
|
||||||
def resolveHostName(
|
def resolveHostName(
|
||||||
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
|
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
|
||||||
) -> IResolutionReceiver:
|
) -> IResolutionReceiver:
|
||||||
|
|
||||||
r = recv()
|
|
||||||
addresses = [] # type: List[IAddress]
|
addresses = [] # type: List[IAddress]
|
||||||
|
|
||||||
def _callback() -> None:
|
def _callback() -> None:
|
||||||
r.resolutionBegan(None)
|
|
||||||
|
|
||||||
has_bad_ip = False
|
has_bad_ip = False
|
||||||
for i in addresses:
|
for address in addresses:
|
||||||
ip_address = IPAddress(i.host)
|
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
|
||||||
|
# should go through this path.
|
||||||
|
if not isinstance(address, (IPv4Address, IPv6Address)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
ip_address = IPAddress(address.host)
|
||||||
|
|
||||||
if check_against_blacklist(
|
if check_against_blacklist(
|
||||||
ip_address, self._ip_whitelist, self._ip_blacklist
|
ip_address, self._ip_whitelist, self._ip_blacklist
|
||||||
|
|
@ -174,15 +185,15 @@ class _IPBlacklistingResolver:
|
||||||
# request, but all we can really do from here is claim that there were no
|
# request, but all we can really do from here is claim that there were no
|
||||||
# valid results.
|
# valid results.
|
||||||
if not has_bad_ip:
|
if not has_bad_ip:
|
||||||
for i in addresses:
|
for address in addresses:
|
||||||
r.addressResolved(i)
|
recv.addressResolved(address)
|
||||||
r.resolutionComplete()
|
recv.resolutionComplete()
|
||||||
|
|
||||||
@provider(IResolutionReceiver)
|
@provider(IResolutionReceiver)
|
||||||
class EndpointReceiver:
|
class EndpointReceiver:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
|
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
|
||||||
pass
|
recv.resolutionBegan(resolutionInProgress)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def addressResolved(address: IAddress) -> None:
|
def addressResolved(address: IAddress) -> None:
|
||||||
|
|
@ -196,10 +207,10 @@ class _IPBlacklistingResolver:
|
||||||
EndpointReceiver, hostname, portNumber=portNumber
|
EndpointReceiver, hostname, portNumber=portNumber
|
||||||
)
|
)
|
||||||
|
|
||||||
return r
|
return recv
|
||||||
|
|
||||||
|
|
||||||
@implementer(IReactorPluggableNameResolver)
|
@implementer(ISynapseReactor)
|
||||||
class BlacklistingReactorWrapper:
|
class BlacklistingReactorWrapper:
|
||||||
"""
|
"""
|
||||||
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
|
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
|
||||||
|
|
@ -325,7 +336,7 @@ class SimpleHttpClient:
|
||||||
# filters out blacklisted IP addresses, to prevent DNS rebinding.
|
# filters out blacklisted IP addresses, to prevent DNS rebinding.
|
||||||
self.reactor = BlacklistingReactorWrapper(
|
self.reactor = BlacklistingReactorWrapper(
|
||||||
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
|
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
|
||||||
)
|
) # type: ISynapseReactor
|
||||||
else:
|
else:
|
||||||
self.reactor = hs.get_reactor()
|
self.reactor = hs.get_reactor()
|
||||||
|
|
||||||
|
|
@ -346,7 +357,7 @@ class SimpleHttpClient:
|
||||||
contextFactory=self.hs.get_http_client_context_factory(),
|
contextFactory=self.hs.get_http_client_context_factory(),
|
||||||
pool=pool,
|
pool=pool,
|
||||||
use_proxy=use_proxy,
|
use_proxy=use_proxy,
|
||||||
)
|
) # type: IAgent
|
||||||
|
|
||||||
if self._ip_blacklist:
|
if self._ip_blacklist:
|
||||||
# If we have an IP blacklist, we then install the blacklisting Agent
|
# If we have an IP blacklist, we then install the blacklisting Agent
|
||||||
|
|
@ -752,6 +763,8 @@ class BodyExceededMaxSize(Exception):
|
||||||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
"""A protocol which immediately errors upon receiving data."""
|
"""A protocol which immediately errors upon receiving data."""
|
||||||
|
|
||||||
|
transport = None # type: Optional[ITCPTransport]
|
||||||
|
|
||||||
def __init__(self, deferred: defer.Deferred):
|
def __init__(self, deferred: defer.Deferred):
|
||||||
self.deferred = deferred
|
self.deferred = deferred
|
||||||
|
|
||||||
|
|
@ -763,18 +776,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
self.deferred.errback(BodyExceededMaxSize())
|
self.deferred.errback(BodyExceededMaxSize())
|
||||||
# Close the connection (forcefully) since all the data will get
|
# Close the connection (forcefully) since all the data will get
|
||||||
# discarded anyway.
|
# discarded anyway.
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def dataReceived(self, data: bytes) -> None:
|
def dataReceived(self, data: bytes) -> None:
|
||||||
self._maybe_fail()
|
self._maybe_fail()
|
||||||
|
|
||||||
def connectionLost(self, reason: Failure) -> None:
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
self._maybe_fail()
|
self._maybe_fail()
|
||||||
|
|
||||||
|
|
||||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||||
|
|
||||||
|
transport = None # type: Optional[ITCPTransport]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
||||||
):
|
):
|
||||||
|
|
@ -797,9 +813,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
self.deferred.errback(BodyExceededMaxSize())
|
self.deferred.errback(BodyExceededMaxSize())
|
||||||
# Close the connection (forcefully) since all the data will get
|
# Close the connection (forcefully) since all the data will get
|
||||||
# discarded anyway.
|
# discarded anyway.
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def connectionLost(self, reason: Failure) -> None:
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
# If the maximum size was already exceeded, there's nothing to do.
|
# If the maximum size was already exceeded, there's nothing to do.
|
||||||
if self.deferred.called:
|
if self.deferred.called:
|
||||||
return
|
return
|
||||||
|
|
@ -868,6 +885,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
|
||||||
return query_str.encode("utf8")
|
return query_str.encode("utf8")
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IPolicyForHTTPS)
|
||||||
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||||
"""
|
"""
|
||||||
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
||||||
|
|
|
||||||
|
|
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
 
     def __init__(
         self,
-        reactor: IReactorCore,
+        reactor: ISynapseReactor,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
         ip_blacklist: IPSet,
@@ -322,7 +322,8 @@ def _cache_period_from_headers(
 
 def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     cache_controls = {}
-    for hdr in headers.getRawHeaders(b"cache-control", []):
+    cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+    for hdr in cache_control_headers:
         for directive in hdr.split(b","):
             splits = [x.strip() for x in directive.split(b"=", 1)]
             k = splits[0].lower()
|
|
|
||||||
|
|
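A small runnable sketch of the header-parsing pattern touched above, assuming only the `or []` handling of a missing Cache-Control header is the point of interest: getRawHeaders() returns None when the header is absent, so the fallback avoids both a crash and a mutable default argument.

    from typing import Dict, Optional

    from twisted.web.http_headers import Headers


    def parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
        cache_controls = {}  # type: Dict[bytes, Optional[bytes]]
        for hdr in headers.getRawHeaders(b"cache-control") or []:
            for directive in hdr.split(b","):
                splits = [x.strip() for x in directive.split(b"=", 1)]
                k = splits[0].lower()
                v = splits[1] if len(splits) > 1 else None
                cache_controls[k] = v
        return cache_controls


    if __name__ == "__main__":
        h = Headers({b"cache-control": [b"max-age=60, no-store"]})
        print(parse_cache_control(h))  # {b'max-age': b'60', b'no-store': None}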
@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
|
||||||
start_active_span,
|
start_active_span,
|
||||||
tags,
|
tags,
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict
|
from synapse.types import ISynapseReactor, JsonDict
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
|
||||||
# addresses, to prevent DNS rebinding.
|
# addresses, to prevent DNS rebinding.
|
||||||
self.reactor = BlacklistingReactorWrapper(
|
self.reactor = BlacklistingReactorWrapper(
|
||||||
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
|
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
|
||||||
)
|
) # type: ISynapseReactor
|
||||||
|
|
||||||
user_agent = hs.version_string
|
user_agent = hs.version_string
|
||||||
if hs.config.user_agent_suffix:
|
if hs.config.user_agent_suffix:
|
||||||
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
|
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
|
||||||
user_agent = user_agent.encode("ascii")
|
user_agent = user_agent.encode("ascii")
|
||||||
|
|
||||||
self.agent = MatrixFederationAgent(
|
federation_agent = MatrixFederationAgent(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
tls_client_options_factory,
|
tls_client_options_factory,
|
||||||
user_agent,
|
user_agent,
|
||||||
|
|
@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
|
||||||
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
|
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
|
||||||
# blacklist via IP literals in server names
|
# blacklist via IP literals in server names
|
||||||
self.agent = BlacklistingAgentWrapper(
|
self.agent = BlacklistingAgentWrapper(
|
||||||
self.agent,
|
federation_agent,
|
||||||
ip_blacklist=hs.config.federation_ip_range_blacklist,
|
ip_blacklist=hs.config.federation_ip_range_blacklist,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -534,9 +534,10 @@ class MatrixFederationHttpClient:
|
||||||
response.code, response_phrase, body
|
response.code, response_phrase, body
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retry if the error is a 429 (Too Many Requests),
|
# Retry if the error is a 5xx or a 429 (Too Many
|
||||||
# otherwise just raise a standard HttpResponseException
|
# Requests), otherwise just raise a standard
|
||||||
if response.code == 429:
|
# `HttpResponseException`
|
||||||
|
if 500 <= response.code < 600 or response.code == 429:
|
||||||
raise RequestSendFailed(exc, can_retry=True) from exc
|
raise RequestSendFailed(exc, can_retry=True) from exc
|
||||||
else:
|
else:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
|
||||||
|
|
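The retry rule introduced above reduces to a one-line predicate; the sketch below uses an illustrative exception type (not Synapse's RequestSendFailed) to show it: retry on any 5xx and on 429, fail permanently otherwise.

    class RetryableRequestError(Exception):
        # Illustrative stand-in for an error carrying a can_retry flag.
        def __init__(self, code: int, can_retry: bool):
            super().__init__("HTTP %d" % code)
            self.code = code
            self.can_retry = can_retry


    def classify_response(code: int) -> RetryableRequestError:
        can_retry = 500 <= code < 600 or code == 429
        return RetryableRequestError(code, can_retry)


    assert classify_response(502).can_retry
    assert classify_response(429).can_retry
    assert not classify_response(404).can_retry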
@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
|
||||||
TCP4ClientEndpoint,
|
TCP4ClientEndpoint,
|
||||||
TCP6ClientEndpoint,
|
TCP6ClientEndpoint,
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
|
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
|
from twisted.internet.tcp import Connection
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -52,7 +53,9 @@ class LogProducer:
|
||||||
format: A callable to format the log record to a string.
|
format: A callable to format the log record to a string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transport = attr.ib(type=ITransport)
|
# This is essentially ITCPTransport, but that is missing certain fields
|
||||||
|
# (connected and registerProducer) which are part of the implementation.
|
||||||
|
transport = attr.ib(type=Connection)
|
||||||
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||||
_buffer = attr.ib(type=deque)
|
_buffer = attr.ib(type=deque)
|
||||||
_paused = attr.ib(default=False, type=bool, init=False)
|
_paused = attr.ib(default=False, type=bool, init=False)
|
||||||
|
|
@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
|
||||||
if self._connection_waiter:
|
if self._connection_waiter:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
|
||||||
|
|
||||||
def fail(failure: Failure) -> None:
|
def fail(failure: Failure) -> None:
|
||||||
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||||
# reconnect (this will cause an infinite loop of errors).
|
# reconnect (this will cause an infinite loop of errors).
|
||||||
|
|
@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
def writer(result: Protocol) -> None:
|
def writer(result: Protocol) -> None:
|
||||||
|
# Force recognising transport as a Connection and not the more
|
||||||
|
# generic ITransport.
|
||||||
|
transport = result.transport # type: Connection # type: ignore
|
||||||
|
|
||||||
# We have a connection. If we already have a producer, and its
|
# We have a connection. If we already have a producer, and its
|
||||||
# transport is the same, just trigger a resumeProducing.
|
# transport is the same, just trigger a resumeProducing.
|
||||||
if self._producer and result.transport is self._producer.transport:
|
if self._producer and transport is self._producer.transport:
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
return
|
return
|
||||||
|
|
@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
|
||||||
# Make a new producer and start it.
|
# Make a new producer and start it.
|
||||||
self._producer = LogProducer(
|
self._producer = LogProducer(
|
||||||
buffer=self._buffer,
|
buffer=self._buffer,
|
||||||
transport=result.transport,
|
transport=transport,
|
||||||
format=self.format,
|
format=self.format,
|
||||||
)
|
)
|
||||||
result.transport.registerProducer(self._producer, True)
|
transport.registerProducer(self._producer, True)
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
|
|
||||||
self._connection_waiter.addCallbacks(writer, fail)
|
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
|
||||||
|
deferred.addCallbacks(writer, fail)
|
||||||
|
self._connection_waiter = deferred
|
||||||
|
|
||||||
def _handle_pressure(self) -> None:
|
def _handle_pressure(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
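For context, a minimal sketch of the push-producer arrangement these hunks adjust, with illustrative names: the producer is registered on a concrete, Connection-style transport (the bare ITransport interface lacks registerProducer and connected), drains a deque of log records, and honours pause/resume for backpressure.

    import logging
    from collections import deque
    from typing import Callable, Deque

    import attr
    from twisted.internet.interfaces import IPushProducer
    from zope.interface import implementer


    @attr.s
    @implementer(IPushProducer)
    class SimpleLogProducer:
        # A concrete transport (e.g. twisted.internet.tcp.Connection) providing
        # write() and registerProducer(); the bare ITransport interface does not.
        transport = attr.ib()
        _format = attr.ib(type=Callable[[logging.LogRecord], str])
        _buffer = attr.ib(type=Deque[logging.LogRecord])
        _paused = attr.ib(default=False, type=bool, init=False)

        def pauseProducing(self) -> None:
            self._paused = True

        def stopProducing(self) -> None:
            self._paused = True
            self._buffer = deque()

        def resumeProducing(self) -> None:
            self._paused = False
            while self._buffer and not self._paused:
                record = self._buffer.popleft()
                self.transport.write(self._format(record).encode("utf8") + b"\n")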
@ -669,7 +669,7 @@ def preserve_fn(f):
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def run_in_background(f, *args, **kwargs):
|
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
||||||
"""Calls a function, ensuring that the current context is restored after
|
"""Calls a function, ensuring that the current context is restored after
|
||||||
return from the function, and that the sentinel context is set once the
|
return from the function, and that the sentinel context is set once the
|
||||||
deferred returned by the function completes.
|
deferred returned by the function completes.
|
||||||
|
|
@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
|
||||||
if isinstance(res, types.CoroutineType):
|
if isinstance(res, types.CoroutineType):
|
||||||
res = defer.ensureDeferred(res)
|
res = defer.ensureDeferred(res)
|
||||||
|
|
||||||
|
# At this point we should have a Deferred, if not then f was a synchronous
|
||||||
|
# function, wrap it in a Deferred for consistency.
|
||||||
if not isinstance(res, defer.Deferred):
|
if not isinstance(res, defer.Deferred):
|
||||||
return res
|
return defer.succeed(res)
|
||||||
|
|
||||||
if res.called and not res.paused:
|
if res.called and not res.paused:
|
||||||
# The function should have maintained the logcontext, so we can
|
# The function should have maintained the logcontext, so we can
|
||||||
|
|
|
||||||
|
|
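A stripped-down sketch of the behaviour change above, leaving out Synapse's logcontext bookkeeping: run_in_background now always hands back a Deferred, wrapping plain synchronous return values with defer.succeed so callers get a uniform interface.

    import types

    from twisted.internet import defer


    def run_in_background(f, *args, **kwargs) -> defer.Deferred:
        res = f(*args, **kwargs)

        if isinstance(res, types.CoroutineType):
            res = defer.ensureDeferred(res)

        # If f was synchronous we still return a Deferred, for consistency.
        if not isinstance(res, defer.Deferred):
            return defer.succeed(res)

        return res


    d = run_in_background(lambda: 42)
    d.addCallback(print)  # prints 42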
@ -203,11 +203,26 @@ class ModuleApi:
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
def generate_short_term_login_token(
|
||||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
self,
|
||||||
|
user_id: str,
|
||||||
|
duration_in_ms: int = (2 * 60 * 1000),
|
||||||
|
auth_provider_id: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a login token suitable for m.login.token authentication"""
|
"""Generate a login token suitable for m.login.token authentication
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: gives the ID of the user that the token is for
|
||||||
|
|
||||||
|
duration_in_ms: the time that the token will be valid for
|
||||||
|
|
||||||
|
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
|
||||||
|
to get this token, if any. This is encoded in the token so that
|
||||||
|
/login can report stats on number of successful logins by IdP.
|
||||||
|
"""
|
||||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
||||||
user_id, duration_in_ms
|
user_id,
|
||||||
|
auth_provider_id,
|
||||||
|
duration_in_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
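A hedged sketch of what threading auth_provider_id into a short-term login token can look like, using pymacaroons caveats; the caveat names and key handling here are illustrative, not necessarily Synapse's actual macaroon format.

    import time

    import pymacaroons


    def generate_short_term_login_token(
        secret_key: str,
        user_id: str,
        auth_provider_id: str = "",
        duration_in_ms: int = 2 * 60 * 1000,
    ) -> str:
        macaroon = pymacaroons.Macaroon(
            location="example.org", identifier="key", key=secret_key
        )
        macaroon.add_first_party_caveat("gen = 1")
        macaroon.add_first_party_caveat("type = login")
        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
        expiry = int(time.time() * 1000) + duration_in_ms
        macaroon.add_first_party_caveat("time < %d" % (expiry,))
        # Record which SSO IdP issued the token so the login path can report
        # stats on successful logins per IdP.
        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
        return macaroon.serialize()


    print(generate_short_term_login_token("a secret", "@alice:example.org", "oidc-github"))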
@ -276,6 +291,7 @@ class ModuleApi:
|
||||||
"""
|
"""
|
||||||
self._auth_handler._complete_sso_login(
|
self._auth_handler._complete_sso_login(
|
||||||
registered_user_id,
|
registered_user_id,
|
||||||
|
"<unknown>",
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
)
|
)
|
||||||
|
|
@ -286,6 +302,7 @@ class ModuleApi:
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
new_user: bool = False,
|
new_user: bool = False,
|
||||||
|
auth_provider_id: str = "<unknown>",
|
||||||
):
|
):
|
||||||
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
||||||
want their access token sent to `client_redirect_url`, or redirect them to that
|
want their access token sent to `client_redirect_url`, or redirect them to that
|
||||||
|
|
@ -299,9 +316,15 @@ class ModuleApi:
|
||||||
redirect them directly if whitelisted).
|
redirect them directly if whitelisted).
|
||||||
new_user: set to true to use wording for the consent appropriate to a user
|
new_user: set to true to use wording for the consent appropriate to a user
|
||||||
who has just registered.
|
who has just registered.
|
||||||
|
auth_provider_id: the ID of the SSO IdP which was used to log in. This
|
||||||
|
is used to track counts of successful logins by IdP.
|
||||||
"""
|
"""
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
registered_user_id, request, client_redirect_url, new_user=new_user
|
registered_user_id,
|
||||||
|
auth_provider_id,
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
|
new_user=new_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from twisted.internet.base import DelayedCall
|
|
||||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
||||||
|
|
@ -66,7 +66,7 @@ class EmailPusher(Pusher):
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.email = pusher_config.pushkey
|
self.email = pusher_config.pushkey
|
||||||
self.timed_call = None # type: Optional[DelayedCall]
|
self.timed_call = None # type: Optional[IDelayedCall]
|
||||||
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
||||||
self._inited = False
|
self._inited = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import urllib
|
import urllib
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Dict, List, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter, Gauge
|
from prometheus_client import Counter, Gauge
|
||||||
|
|
||||||
|
|
@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_pending_outgoing_requests = Gauge(
|
_pending_outgoing_requests = Gauge(
|
||||||
|
|
@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
CACHE = True
|
CACHE = True
|
||||||
RETRY_ON_TIMEOUT = True
|
RETRY_ON_TIMEOUT = True
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
if self.CACHE:
|
if self.CACHE:
|
||||||
self.response_cache = ResponseCache(
|
self.response_cache = ResponseCache(
|
||||||
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||||
) # type: ResponseCache[str]
|
) # type: ResponseCache[str]
|
||||||
|
|
||||||
# We reserve `instance_name` as a parameter to sending requests, so we
|
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
is_appservice_ghost = content["is_appservice_ghost"]
|
is_appservice_ghost = content["is_appservice_ghost"]
|
||||||
|
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
res = await self.registration_handler.register_device_inner(
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
initial_display_name,
|
initial_display_name,
|
||||||
|
|
@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {"device_id": device_id, "access_token": access_token}
|
return 200, res
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
|
||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.protocol import AbstractConnection
|
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||||
from synapse.replication.tcp.streams import (
|
from synapse.replication.tcp.streams import (
|
||||||
STREAMS_MAP,
|
STREAMS_MAP,
|
||||||
AccountDataStream,
|
AccountDataStream,
|
||||||
|
|
@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
|
||||||
|
|
||||||
# the type of the entries in _command_queues_by_stream
|
# the type of the entries in _command_queues_by_stream
|
||||||
_StreamCommandQueue = Deque[
|
_StreamCommandQueue = Deque[
|
||||||
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
|
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -174,7 +174,7 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# The currently connected connections. (The list of places we need to send
|
# The currently connected connections. (The list of places we need to send
|
||||||
# outgoing replication commands to.)
|
# outgoing replication commands to.)
|
||||||
self._connections = [] # type: List[AbstractConnection]
|
self._connections = [] # type: List[IReplicationConnection]
|
||||||
|
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_replication_tcp_resource_total_connections",
|
"synapse_replication_tcp_resource_total_connections",
|
||||||
|
|
@ -197,7 +197,7 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# For each connection, the incoming stream names that have received a POSITION
|
# For each connection, the incoming stream names that have received a POSITION
|
||||||
# from that connection.
|
# from that connection.
|
||||||
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
|
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
|
||||||
|
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_replication_tcp_command_queue",
|
"synapse_replication_tcp_command_queue",
|
||||||
|
|
@ -220,7 +220,7 @@ class ReplicationCommandHandler:
|
||||||
self._server_notices_sender = hs.get_server_notices_sender()
|
self._server_notices_sender = hs.get_server_notices_sender()
|
||||||
|
|
||||||
def _add_command_to_stream_queue(
|
def _add_command_to_stream_queue(
|
||||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Queue the given received command for processing
|
"""Queue the given received command for processing
|
||||||
|
|
||||||
|
|
@ -267,7 +267,7 @@ class ReplicationCommandHandler:
|
||||||
async def _process_command(
|
async def _process_command(
|
||||||
self,
|
self,
|
||||||
cmd: Union[PositionCommand, RdataCommand],
|
cmd: Union[PositionCommand, RdataCommand],
|
||||||
conn: AbstractConnection,
|
conn: IReplicationConnection,
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(cmd, PositionCommand):
|
if isinstance(cmd, PositionCommand):
|
||||||
|
|
@ -302,7 +302,7 @@ class ReplicationCommandHandler:
|
||||||
hs, outbound_redis_connection
|
hs, outbound_redis_connection
|
||||||
)
|
)
|
||||||
hs.get_reactor().connectTCP(
|
hs.get_reactor().connectTCP(
|
||||||
hs.config.redis.redis_host,
|
hs.config.redis.redis_host.encode(),
|
||||||
hs.config.redis.redis_port,
|
hs.config.redis.redis_port,
|
||||||
self._factory,
|
self._factory,
|
||||||
)
|
)
|
||||||
|
|
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
|
||||||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
||||||
host = hs.config.worker_replication_host
|
host = hs.config.worker_replication_host
|
||||||
port = hs.config.worker_replication_port
|
port = hs.config.worker_replication_port
|
||||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
|
||||||
|
|
||||||
def get_streams(self) -> Dict[str, Stream]:
|
def get_streams(self) -> Dict[str, Stream]:
|
||||||
"""Get a map from stream name to all streams."""
|
"""Get a map from stream name to all streams."""
|
||||||
|
|
@ -321,10 +321,10 @@ class ReplicationCommandHandler:
|
||||||
"""Get a list of streams that this instances replicates."""
|
"""Get a list of streams that this instances replicates."""
|
||||||
return self._streams_to_replicate
|
return self._streams_to_replicate
|
||||||
|
|
||||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
||||||
self.send_positions_to_connection(conn)
|
self.send_positions_to_connection(conn)
|
||||||
|
|
||||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
def send_positions_to_connection(self, conn: IReplicationConnection):
|
||||||
"""Send current position of all streams this process is source of to
|
"""Send current position of all streams this process is source of to
|
||||||
the connection.
|
the connection.
|
||||||
"""
|
"""
|
||||||
|
|
@ -347,7 +347,7 @@ class ReplicationCommandHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_USER_SYNC(
|
def on_USER_SYNC(
|
||||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
self, conn: IReplicationConnection, cmd: UserSyncCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
user_sync_counter.inc()
|
user_sync_counter.inc()
|
||||||
|
|
||||||
|
|
@ -359,21 +359,23 @@ class ReplicationCommandHandler:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_CLEAR_USER_SYNC(
|
def on_CLEAR_USER_SYNC(
|
||||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
if self._is_master:
|
if self._is_master:
|
||||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
def on_FEDERATION_ACK(
|
||||||
|
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||||
|
):
|
||||||
federation_ack_counter.inc()
|
federation_ack_counter.inc()
|
||||||
|
|
||||||
if self._federation_sender:
|
if self._federation_sender:
|
||||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||||
|
|
||||||
def on_USER_IP(
|
def on_USER_IP(
|
||||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
self, conn: IReplicationConnection, cmd: UserIpCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
user_ip_cache_counter.inc()
|
user_ip_cache_counter.inc()
|
||||||
|
|
||||||
|
|
@ -395,7 +397,7 @@ class ReplicationCommandHandler:
|
||||||
assert self._server_notices_sender is not None
|
assert self._server_notices_sender is not None
|
||||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||||
|
|
||||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore RDATA that are just our own echoes
|
# Ignore RDATA that are just our own echoes
|
||||||
return
|
return
|
||||||
|
|
@ -412,7 +414,7 @@ class ReplicationCommandHandler:
|
||||||
self._add_command_to_stream_queue(conn, cmd)
|
self._add_command_to_stream_queue(conn, cmd)
|
||||||
|
|
||||||
async def _process_rdata(
|
async def _process_rdata(
|
||||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process an RDATA command
|
"""Process an RDATA command
|
||||||
|
|
||||||
|
|
@ -486,7 +488,7 @@ class ReplicationCommandHandler:
|
||||||
stream_name, instance_name, token, rows
|
stream_name, instance_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore POSITION that are just our own echoes
|
# Ignore POSITION that are just our own echoes
|
||||||
return
|
return
|
||||||
|
|
@ -496,7 +498,7 @@ class ReplicationCommandHandler:
|
||||||
self._add_command_to_stream_queue(conn, cmd)
|
self._add_command_to_stream_queue(conn, cmd)
|
||||||
|
|
||||||
async def _process_position(
|
async def _process_position(
|
||||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a POSITION command
|
"""Process a POSITION command
|
||||||
|
|
||||||
|
|
@ -553,7 +555,9 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||||
|
|
||||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
def on_REMOTE_SERVER_UP(
|
||||||
|
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||||
|
):
|
||||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||||
|
|
||||||
|
|
@ -576,7 +580,7 @@ class ReplicationCommandHandler:
|
||||||
# between two instances, but that is not currently supported).
|
# between two instances, but that is not currently supported).
|
||||||
self.send_command(cmd, ignore_conn=conn)
|
self.send_command(cmd, ignore_conn=conn)
|
||||||
|
|
||||||
def new_connection(self, connection: AbstractConnection):
|
def new_connection(self, connection: IReplicationConnection):
|
||||||
"""Called when we have a new connection."""
|
"""Called when we have a new connection."""
|
||||||
self._connections.append(connection)
|
self._connections.append(connection)
|
||||||
|
|
||||||
|
|
@ -603,7 +607,7 @@ class ReplicationCommandHandler:
|
||||||
UserSyncCommand(self._instance_id, user_id, True, now)
|
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||||
)
|
)
|
||||||
|
|
||||||
def lost_connection(self, connection: AbstractConnection):
|
def lost_connection(self, connection: IReplicationConnection):
|
||||||
"""Called when a connection is closed/lost."""
|
"""Called when a connection is closed/lost."""
|
||||||
# we no longer need _streams_by_connection for this connection.
|
# we no longer need _streams_by_connection for this connection.
|
||||||
streams = self._streams_by_connection.pop(connection, None)
|
streams = self._streams_by_connection.pop(connection, None)
|
||||||
|
|
@ -624,7 +628,7 @@ class ReplicationCommandHandler:
|
||||||
return bool(self._connections)
|
return bool(self._connections)
|
||||||
|
|
||||||
def send_command(
|
def send_command(
|
||||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||||
):
|
):
|
||||||
"""Send a command to all connected connections.
|
"""Send a command to all connected connections.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
|
||||||
> ERROR server stopping
|
> ERROR server stopping
|
||||||
* connection closed by server *
|
* connection closed by server *
|
||||||
"""
|
"""
|
||||||
import abc
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
|
@ -54,8 +53,10 @@ from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
from zope.interface import Interface, implementer
|
||||||
|
|
||||||
from twisted.internet import task
|
from twisted.internet import task
|
||||||
|
from twisted.internet.tcp import Connection
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
from twisted.protocols.basic import LineOnlyReceiver
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
|
|
@ -121,6 +122,14 @@ class ConnectionStates:
|
||||||
CLOSED = "closed"
|
CLOSED = "closed"
|
||||||
|
|
||||||
|
|
||||||
|
class IReplicationConnection(Interface):
|
||||||
|
"""An interface for replication connections."""
|
||||||
|
|
||||||
|
def send_command(cmd: Command):
|
||||||
|
"""Send the command down the connection"""
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IReplicationConnection)
|
||||||
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
"""Base replication protocol shared between client and server.
|
"""Base replication protocol shared between client and server.
|
||||||
|
|
||||||
|
|
@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
(if they send a `PING` command)
|
(if they send a `PING` command)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# The transport is going to be an ITCPTransport, but that doesn't have the
|
||||||
|
# (un)registerProducer methods, those are only on the implementation.
|
||||||
|
transport = None # type: Connection
|
||||||
|
|
||||||
delimiter = b"\n"
|
delimiter = b"\n"
|
||||||
|
|
||||||
# Valid commands we expect to receive
|
# Valid commands we expect to receive
|
||||||
|
|
@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
|
|
||||||
connected_connections.append(self) # Register connection for metrics
|
connected_connections.append(self) # Register connection for metrics
|
||||||
|
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.registerProducer(self, True) # For the *Producing callbacks
|
self.transport.registerProducer(self, True) # For the *Producing callbacks
|
||||||
|
|
||||||
self._send_pending_commands()
|
self._send_pending_commands()
|
||||||
|
|
@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
logger.info(
|
logger.info(
|
||||||
"[%s] Failed to close connection gracefully, aborting", self.id()
|
"[%s] Failed to close connection gracefully, aborting", self.id()
|
||||||
)
|
)
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
else:
|
else:
|
||||||
if now - self.last_sent_command >= PING_TIME:
|
if now - self.last_sent_command >= PING_TIME:
|
||||||
|
|
@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
def close(self):
|
def close(self):
|
||||||
logger.warning("[%s] Closing connection", self.id())
|
logger.warning("[%s] Closing connection", self.id())
|
||||||
self.time_we_closed = self.clock.time_msec()
|
self.time_we_closed = self.clock.time_msec()
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.loseConnection()
|
self.transport.loseConnection()
|
||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
|
|
@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
||||||
if isinstance(reason, Failure):
|
if isinstance(reason, Failure):
|
||||||
|
assert reason.type is not None
|
||||||
connection_close_counter.labels(reason.type.__name__).inc()
|
connection_close_counter.labels(reason.type.__name__).inc()
|
||||||
else:
|
else:
|
||||||
connection_close_counter.labels(reason.__class__.__name__).inc()
|
connection_close_counter.labels(reason.__class__.__name__).inc()
|
||||||
|
|
@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
self.send_command(ReplicateCommand())
|
self.send_command(ReplicateCommand())
|
||||||
|
|
||||||
|
|
||||||
class AbstractConnection(abc.ABC):
|
|
||||||
"""An interface for replication connections."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def send_command(self, cmd: Command):
|
|
||||||
"""Send the command down the connection"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# This tells python that `BaseReplicationStreamProtocol` implements the
|
|
||||||
# interface.
|
|
||||||
AbstractConnection.register(BaseReplicationStreamProtocol)
|
|
||||||
|
|
||||||
|
|
||||||
# The following simply registers metrics for the replication connections
|
# The following simply registers metrics for the replication connections
|
||||||
|
|
||||||
pending_commands = LaterGauge(
|
pending_commands = LaterGauge(
|
||||||
|
|
|
||||||
|
|
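The hunks above swap an abc.ABC base class (plus an explicit register() call) for a zope.interface Interface with an @implementer decorator, which composes better with Twisted's own interface machinery. A self-contained sketch of the same pattern, with an illustrative implementation class:

    from zope.interface import Interface, implementer
    from zope.interface.verify import verifyClass


    class IReplicationConnection(Interface):
        """An interface for replication connections."""

        def send_command(cmd):
            """Send the command down the connection"""


    @implementer(IReplicationConnection)
    class InMemoryConnection:
        """Illustrative implementation that simply records the commands it is given."""

        def __init__(self):
            self.sent = []

        def send_command(self, cmd):
            self.sent.append(cmd)


    # Unlike abc registration, conformance can be checked explicitly:
    verifyClass(IReplicationConnection, InMemoryConnection)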
@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import txredisapi
|
import txredisapi
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
|
from twisted.internet.interfaces import IAddress, IConnector
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
|
@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
|
||||||
parse_command_from_line,
|
parse_command_from_line,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.protocol import (
|
from synapse.replication.tcp.protocol import (
|
||||||
AbstractConnection,
|
IReplicationConnection,
|
||||||
tcp_inbound_commands_counter,
|
tcp_inbound_commands_counter,
|
||||||
tcp_outbound_commands_counter,
|
tcp_outbound_commands_counter,
|
||||||
)
|
)
|
||||||
|
|
@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
@implementer(IReplicationConnection)
|
||||||
|
class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||||
"""Connection to redis subscribed to replication stream.
|
"""Connection to redis subscribed to replication stream.
|
||||||
|
|
||||||
This class fulfils two functions:
|
This class fulfils two functions:
|
||||||
|
|
@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
connection, parsing *incoming* messages into replication commands, and passing them
|
connection, parsing *incoming* messages into replication commands, and passing them
|
||||||
to `ReplicationCommandHandler`
|
to `ReplicationCommandHandler`
|
||||||
|
|
||||||
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
|
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
|
||||||
onto outbound_redis_connection.
|
onto outbound_redis_connection.
|
||||||
|
|
||||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||||
|
|
@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to send ping to a redis connection")
|
logger.warning("Failed to send ping to a redis connection")
|
||||||
|
|
||||||
|
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
|
||||||
|
# it's rubbish. We add our own here.
|
||||||
|
|
||||||
|
def startedConnecting(self, connector: IConnector):
|
||||||
|
logger.info(
|
||||||
|
"Connecting to redis server %s", format_address(connector.getDestination())
|
||||||
|
)
|
||||||
|
super().startedConnecting(connector)
|
||||||
|
|
||||||
|
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
|
||||||
|
logger.info(
|
||||||
|
"Connection to redis server %s failed: %s",
|
||||||
|
format_address(connector.getDestination()),
|
||||||
|
reason.value,
|
||||||
|
)
|
||||||
|
super().clientConnectionFailed(connector, reason)
|
||||||
|
|
||||||
|
def clientConnectionLost(self, connector: IConnector, reason: Failure):
|
||||||
|
logger.info(
|
||||||
|
"Connection to redis server %s lost: %s",
|
||||||
|
format_address(connector.getDestination()),
|
||||||
|
reason.value,
|
||||||
|
)
|
||||||
|
super().clientConnectionLost(connector, reason)
|
||||||
|
|
||||||
|
|
||||||
|
def format_address(address: IAddress) -> str:
|
||||||
|
if isinstance(address, (IPv4Address, IPv6Address)):
|
||||||
|
return "%s:%i" % (address.host, address.port)
|
||||||
|
return str(address)
|
||||||
|
|
||||||
|
|
||||||
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||||
"""This is a reconnecting factory that connects to redis and immediately
|
"""This is a reconnecting factory that connects to redis and immediately
|
||||||
|
|
@ -328,6 +365,6 @@ def lazyConnection(
|
||||||
factory.continueTrying = reconnect
|
factory.continueTrying = reconnect
|
||||||
|
|
||||||
reactor = hs.get_reactor()
|
reactor = hs.get_reactor()
|
||||||
reactor.connectTCP(host, port, factory, 30)
|
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
|
||||||
|
|
||||||
return factory.handler
|
return factory.handler
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,9 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import twisted.web.server
|
from synapse.api.auth import Auth
|
||||||
|
|
||||||
import synapse.api.auth
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
|
||||||
return patterns
|
return patterns
|
||||||
|
|
||||||
|
|
||||||
async def assert_requester_is_admin(
|
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
|
||||||
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
|
|
||||||
) -> None:
|
|
||||||
"""Verify that the requester is an admin user
|
"""Verify that the requester is an admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth: api.auth.Auth singleton
|
auth: Auth singleton
|
||||||
request: incoming request
|
request: incoming request
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
|
@ -53,11 +50,11 @@ async def assert_requester_is_admin(
|
||||||
await assert_user_is_admin(auth, requester.user)
|
await assert_user_is_admin(auth, requester.user)
|
||||||
|
|
||||||
|
|
||||||
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
|
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
|
||||||
"""Verify that the given user is an admin user
|
"""Verify that the given user is an admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth: api.auth.Auth singleton
|
auth: Auth singleton
|
||||||
user_id: user to check
|
user_id: user to check
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,9 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from twisted.web.server import Request
|
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.admin._base import (
|
from synapse.rest.admin._base import (
|
||||||
admin_patterns,
|
admin_patterns,
|
||||||
assert_requester_is_admin,
|
assert_requester_is_admin,
|
||||||
|
|
@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
|
@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
|
@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: Request, server_name: str, media_id: str
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, media_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
|
@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
is_admin = await self.auth.is_server_admin(requester.user)
|
is_admin = await self.auth.is_server_admin(requester.user)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
|
|
@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
before_ts = parse_integer(request, "before_ts", required=True)
|
before_ts = parse_integer(request, "before_ts", required=True)
|
||||||
|
|
@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
|
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, server_name: str, media_id: str
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
|
@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, server_name: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
before_ts = parse_integer(request, "before_ts", required=True)
|
before_ts = parse_integer(request, "before_ts", required=True)
|
||||||
|
|
|
||||||
|
|
@ -12,13 +12,20 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.admin import assert_requester_is_admin
|
from synapse.rest.admin import assert_requester_is_admin
|
||||||
from synapse.rest.admin._base import admin_patterns
|
from synapse.rest.admin._base import admin_patterns
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
class PurgeRoomServlet(RestServlet):
|
class PurgeRoomServlet(RestServlet):
|
||||||
|
|
@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = admin_patterns("/purge_room$")
|
PATTERNS = admin_patterns("/purge_room$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer): server
|
|
||||||
"""
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.pagination_handler = hs.get_pagination_handler()
|
self.pagination_handler = hs.get_pagination_handler()
|
||||||
|
|
||||||
async def on_POST(self, request):
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
|
||||||
|
|
@ -685,7 +685,10 @@ class RoomEventContextServlet(RestServlet):
|
||||||
results["events_after"], time_now
|
results["events_after"], time_now
|
||||||
)
|
)
|
||||||
results["state"] = await self._event_serializer.serialize_events(
|
results["state"] = await self._event_serializer.serialize_events(
|
||||||
results["state"], time_now
|
results["state"],
|
||||||
|
time_now,
|
||||||
|
# No need to bundle aggregations for state events
|
||||||
|
bundle_aggregations=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, results
|
return 200, results
|
||||||
|
|
|
||||||
|
|
@ -12,17 +12,24 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.admin import assert_requester_is_admin
|
from synapse.rest.admin import assert_requester_is_admin
|
||||||
from synapse.rest.admin._base import admin_patterns
|
from synapse.rest.admin._base import admin_patterns
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
class SendServerNoticeServlet(RestServlet):
|
class SendServerNoticeServlet(RestServlet):
|
||||||
|
|
@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer): server
|
|
||||||
"""
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.txns = HttpTransactionCache(hs)
|
self.txns = HttpTransactionCache(hs)
|
||||||
self.snm = hs.get_server_notices_manager()
|
self.snm = hs.get_server_notices_manager()
|
||||||
|
|
||||||
def register(self, json_resource):
|
def register(self, json_resource: HttpServer):
|
||||||
PATTERN = "/send_server_notice"
|
PATTERN = "/send_server_notice"
|
||||||
json_resource.register_paths(
|
json_resource.register_paths(
|
||||||
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
|
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
|
||||||
|
|
@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_POST(self, request, txn_id=None):
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, txn_id: Optional[str] = None
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
assert_params_in_dict(body, ("user_id", "content"))
|
assert_params_in_dict(body, ("user_id", "content"))
|
||||||
|
|
@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
|
||||||
|
|
||||||
return 200, {"event_id": event.event_id}
|
return 200, {"event_id": event.event_id}
|
||||||
|
|
||||||
def on_PUT(self, request, txn_id):
|
def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
|
||||||
return self.txns.fetch_or_execute_request(
|
return self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, txn_id
|
request, self.on_POST, request, txn_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
|
||||||
target_user.to_string(), False, requester, by_admin=True
|
target_user.to_string(), False, requester, by_admin=True
|
||||||
)
|
)
|
||||||
elif not deactivate and user["deactivated"]:
|
elif not deactivate and user["deactivated"]:
|
||||||
if "password" not in body:
|
if (
|
||||||
|
"password" not in body
|
||||||
|
and self.hs.config.password_localdb_enabled
|
||||||
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Must provide a password to re-activate an account."
|
400, "Must provide a password to re-activate an account."
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
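Stated on its own, the re-activation rule above is: only require a password when the local password database is enabled, so SSO-only accounts can be re-activated without one. A tiny sketch with illustrative names:

    def check_reactivation_body(body: dict, password_localdb_enabled: bool) -> None:
        # Mirrors the guard above: a password is only mandatory when local
        # password auth is actually in use.
        if "password" not in body and password_localdb_enabled:
            raise ValueError("Must provide a password to re-activate an account.")


    check_reactivation_body({}, password_localdb_enabled=False)    # OK: SSO-only deployment
    check_reactivation_body({"password": "s3cret"}, True)          # OK
    # check_reactivation_body({}, True) would raise ValueError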
@ -14,10 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
from synapse.api.urls import CLIENT_API_PREFIX
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.sso import SsoIdentityProvider
|
from synapse.handlers.sso import SsoIdentityProvider
|
||||||
from synapse.http import get_request_uri
|
from synapse.http import get_request_uri
|
||||||
|
|
@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
|
||||||
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
||||||
|
|
||||||
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
||||||
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
|
sso_flow = {
|
||||||
|
"type": LoginRestServlet.SSO_TYPE,
|
||||||
|
"identity_providers": [
|
||||||
|
_get_auth_flow_dict_for_idp(
|
||||||
|
idp,
|
||||||
|
)
|
||||||
|
for idp in self._sso_handler.get_identity_providers().values()
|
||||||
|
],
|
||||||
|
} # type: JsonDict
|
||||||
|
|
||||||
if self._msc2858_enabled:
|
if self._msc2858_enabled:
|
||||||
|
# backwards-compatibility support for clients which don't
|
||||||
|
# support the stable API yet
|
||||||
sso_flow["org.matrix.msc2858.identity_providers"] = [
|
sso_flow["org.matrix.msc2858.identity_providers"] = [
|
||||||
_get_auth_flow_dict_for_idp(idp)
|
_get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
|
||||||
for idp in self._sso_handler.get_identity_providers().values()
|
for idp in self._sso_handler.get_identity_providers().values()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
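A hedged sketch of the login-flows payload assembled above: the m.login.sso flow now advertises the configured identity providers, keeping the unstable org.matrix.msc2858 key for older clients. For brevity this sketch reuses the same provider list for both keys, whereas the code above builds the unstable list with unstable brand identifiers.

    from typing import Any, Dict, List


    def build_sso_flow(idps: List[Dict[str, Any]], msc2858_enabled: bool) -> Dict[str, Any]:
        flow = {
            "type": "m.login.sso",
            "identity_providers": idps,
        }  # type: Dict[str, Any]
        if msc2858_enabled:
            # Backwards-compatibility for clients which only know the unstable field.
            flow["org.matrix.msc2858.identity_providers"] = idps
        return flow


    print(build_sso_flow([{"id": "oidc-github", "name": "GitHub"}], msc2858_enabled=True))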
@@ -219,6 +231,7 @@ class LoginRestServlet(RestServlet):
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
+        auth_provider_id: Optional[str] = None,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -234,6 +247,8 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -256,7 +271,7 @@ class LoginRestServlet(RestServlet):
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name
+            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
         )
 
         result = {
@@ -283,12 +298,13 @@ class LoginRestServlet(RestServlet):
         """
         token = login_submission["token"]
         auth_handler = self.auth_handler
-        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
-            token
-        )
+        res = await auth_handler.validate_short_term_login_token(token)
 
         return await self._complete_login(
-            user_id, login_submission, self.auth_handler._sso_login_callback
+            res.user_id,
+            login_submission,
+            self.auth_handler._sso_login_callback,
+            auth_provider_id=res.auth_provider_id,
         )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
@@ -327,22 +343,38 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+    idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
     """Return an entry for the login flow dict
 
     Returns an entry suitable for inclusion in "identity_providers" in the
     response to GET /_matrix/client/r0/login
 
+    Args:
+        idp: the identity provider to describe
+        use_unstable_brands: whether we should use brand identifiers suitable
+            for the unstable API
     """
     e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
     if idp.idp_icon:
         e["icon"] = idp.idp_icon
     if idp.idp_brand:
         e["brand"] = idp.idp_brand
+    # use the stable brand identifier if the unstable identifier isn't defined.
+    if use_unstable_brands and idp.unstable_idp_brand:
+        e["brand"] = idp.unstable_idp_brand
     return e
 
 
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+    PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+        re.compile(
+            "^"
+            + CLIENT_API_PREFIX
+            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+        )
+    ]
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
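The extra compiled pattern lets a client name an IdP directly in the redirect URL. A quick check of what it matches, assuming CLIENT_API_PREFIX resolves to "/_matrix/client" (the prefix and idp id below are illustrative assumptions, not taken from this diff):

    # Quick illustration of the extra redirect pattern.
    import re

    pattern = re.compile(
        "^" + "/_matrix/client" + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
    )
    m = pattern.match("/_matrix/client/r0/login/sso/redirect/oidc-github")
    assert m is not None and m.group("idp_id") == "oidc-github"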
@@ -360,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
         if self._msc2858_enabled:
-            # expose additional endpoint for MSC2858 support
+            # expose additional endpoint for MSC2858 support: backwards-compat support
+            # for clients which don't yet support the stable endpoints.
             http_server.register_paths(
                 "GET",
                 client_patterns(
@@ -674,7 +674,10 @@ class RoomEventContextServlet(RestServlet):
             results["events_after"], time_now
         )
         results["state"] = await self._event_serializer.serialize_events(
-            results["state"], time_now
+            results["state"],
+            time_now,
+            # No need to bundle aggregations for state events
+            bundle_aggregations=False,
         )
 
         return 200, results
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.http.site import SynapseRequest
 from synapse.types import GroupID, JsonDict
 
 from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
         return 200, group_description
 
     @_validate_group_id
-    async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+        self,
+        request: SynapseRequest,
+        group_id: str,
+        category_id: Optional[str],
+        room_id: str,
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, category_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_GET(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_GET(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+        self,
+        request: SynapseRequest,
+        group_id: str,
+        role_id: Optional[str],
+        user_id: str,
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, role_id: str, user_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
         self.server_name = hs.hostname
 
-    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, room_id: str, config_key: str
+        self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.store = hs.get_datastore()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 from twisted.web.server import Request
 
 from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
@@ -35,6 +35,7 @@ from synapse.api.errors import (
 from synapse.config._base import ConfigError
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
         upload_name: Optional[str],
         content: IO,
         content_length: int,
-        auth_user: str,
+        auth_user: UserID,
     ) -> str:
         """Store uploaded content for a local user and return the mxc URL
 
@@ -39,6 +39,7 @@ from synapse.http.server import (
     respond_with_json_bytes,
 )
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
@@ -174,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -96,9 +96,14 @@ class Thumbnailer:
     def _resize(self, width: int, height: int) -> Image:
         # 1-bit or 8-bit color palette images need converting to RGB
         # otherwise they will be scaled using nearest neighbour which
-        # looks awful
-        if self.image.mode in ["1", "P"]:
-            self.image = self.image.convert("RGB")
+        # looks awful.
+        #
+        # If the image has transparency, use RGBA instead.
+        if self.image.mode in ["1", "L", "P"]:
+            mode = "RGB"
+            if self.image.info.get("transparency", None) is not None:
+                mode = "RGBA"
+            self.image = self.image.convert(mode)
         return self.image.resize((width, height), Image.ANTIALIAS)
 
     def scale(self, width: int, height: int, output_type: str) -> BytesIO:
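The _resize change is what fixes the thumbnailing TypeError for images with transparency: palette and greyscale images that carry a "transparency" entry are converted to RGBA instead of RGB before resizing. A minimal standalone sketch of the same mode-selection logic using Pillow directly (the image path and sizes are illustrative, and this is not the repository's Thumbnailer class):

    # Standalone sketch, assuming Pillow is installed.
    from PIL import Image

    def resize_keeping_transparency(path: str, width: int, height: int) -> Image.Image:
        image = Image.open(path)
        if image.mode in ("1", "L", "P"):
            # Palette/greyscale images with a transparency entry need RGBA;
            # plain RGB is enough otherwise.
            mode = "RGBA" if image.info.get("transparency") is not None else "RGB"
            image = image.convert(mode)
        return image.resize((width, height), Image.ANTIALIAS)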
@@ -22,6 +22,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.rest.media.v1.media_storage import SpamMediaException
 
 if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
     async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request: Request) -> None:
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
@@ -14,24 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
 from synapse.http.server import DirectServeHtmlResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class SAML2ResponseResource(DirectServeHtmlResource):
     """A Twisted web resource which handles the SAML response"""
 
     isLeaf = 1
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._saml_handler = hs.get_saml_handler()
+        self._sso_handler = hs.get_sso_handler()
 
     async def _async_render_GET(self, request):
         # We're not expecting any GET request on that resource if everything goes right,
         # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
         # In this case, just tell the user that something went wrong and they should
         # try to authenticate again.
-        self._saml_handler._render_error(
+        self._sso_handler.render_error(
             request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
         )
@@ -36,7 +36,6 @@ from typing import (
     cast,
 )
 
-import twisted.internet.base
 import twisted.internet.tcp
 from twisted.internet import defer
 from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
 from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import Databases, DataStore, Storage
 from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
             getattr(self, "get_" + i + "_handler")()
 
-    def get_reactor(self) -> twisted.internet.base.ReactorBase:
+    def get_reactor(self) -> ISynapseReactor:
         """
         Fetch the Twisted reactor in use by this HomeServer.
         """
@@ -352,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_http_client_context_factory(self) -> IPolicyForHTTPS:
-        return (
-            InsecureInterceptableContextFactory()
-            if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
-            else RegularPolicyForHTTPS()
-        )
+        if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+            return InsecureInterceptableContextFactory()
+        return RegularPolicyForHTTPS()
 
     @cache_in_self
     def get_simple_http_client(self) -> SimpleHttpClient:
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )  # type: LruCache[str, List[Tuple[str, int]]]
 
     async def get_auth_chain(
-        self, event_ids: Collection[str], include_given: bool = False
+        self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             list of events
         """
         event_ids = await self.get_auth_chain_ids(
-            event_ids, include_given=include_given
+            room_id, event_ids, include_given=include_given
         )
         return await self.get_events_as_list(event_ids)
 
     async def get_auth_chain_ids(
         self,
+        room_id: str,
         event_ids: Collection[str],
         include_given: bool = False,
     ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
         Returns:
-            An awaitable which resolve to a list of event_ids
+            list of event_ids
         """
 
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_ids_chains",
+                    self._get_auth_chain_ids_using_cover_index_txn,
+                    room_id,
+                    event_ids,
+                    include_given,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
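The chain-cover path looks up the chain id and sequence number for each given event, then walks event_auth_chain_links to find every chain reachable from them. A toy, in-memory illustration of the reachability rule used in the transaction below (plain dicts stand in for the database tables; the numbers are made up):

    # chain id -> max sequence number held among the given events
    event_chains = {1: 4, 7: 2}
    # (origin_chain, origin_seq, target_chain, target_seq) link rows
    links = [(1, 3, 9, 5), (7, 6, 3, 2)]

    chains = {}
    for origin_chain, origin_seq, target_chain, target_seq in links:
        # a link only counts if its origin sequence number is at or below
        # the maximum we already hold for that origin chain
        if origin_seq <= event_chains.get(origin_chain, 0):
            chains[target_chain] = max(target_seq, chains.get(target_chain, 0))
    for chain_id, seq in event_chains.items():
        chains[chain_id] = max(seq - 1, chains.get(chain_id, 0))

    assert chains == {9: 5, 1: 3, 7: 1}
    # every event whose (chain_id, sequence_number) falls at or below these
    # maxima is part of the auth chain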
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             include_given,
         )
 
+    def _get_auth_chain_ids_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
+        """Calculates the auth chain IDs using the chain index."""
+
+        # First we look up the chain ID/sequence numbers for the given events.
+
+        initial_events = set(event_ids)
+
+        # All the events that we've found that are reachable from the events.
+        seen_events = set()  # type: Set[str]
+
+        # A map from chain ID to max sequence number of the given events.
+        event_chains = {}  # type: Dict[int, int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                seen_events.add(event_id)
+                event_chains[chain_id] = max(
+                    sequence_number, event_chains.get(chain_id, 0)
+                )
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(seen_events)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Now we look up all links for the chains we have, adding chains that
+        # are reachable from any event.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # A map from chain ID to max sequence number *reachable* from any event ID.
+        chains = {}  # type: Dict[int, int]
+
+        # Add all linked chains reachable from initial set of chains.
+        for batch in batch_iter(event_chains, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                # chains are only reachable if the origin sequence number of
+                # the link is less than the max sequence number in the
+                # origin chain.
+                if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+                    chains[target_chain_id] = max(
+                        target_sequence_number,
+                        chains.get(target_chain_id, 0),
+                    )
+
+        # Add the initial set of chains, excluding the sequence corresponding to
+        # initial event.
+        for chain_id, seq_no in event_chains.items():
+            chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* event ID. Events with a sequence less than that are in the
+        # auth chain.
+        if include_given:
+            results = initial_events
+        else:
+            results = set()
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND sequence_number <= max_seq
+            """
+
+            rows = txn.execute_values(sql, chains.items())
+            results.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND sequence_number <= ?
+            """
+            for chain_id, max_no in chains.items():
+                txn.execute(sql, (chain_id, max_no))
+                results.update(r for r, in txn)
+
+        return list(results)
+
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
     ) -> List[str]:
+        """Calculates the auth chain IDs.
+
+        This is used when we don't have a cover index for the room.
+        """
         if include_given:
             results = set(event_ids)
         else:
@@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._chain_cover_index,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "purged_chain_cover",
+            self._purged_chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             processed_count=count,
             finished_room_map=finished_rooms,
         )
+
+    async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """
+        A background updates that iterates over the chain cover and deletes the
+        chain cover for events that have been purged.
+
+        This may be due to fully purging a room or via setting a retention policy.
+        """
+        current_event_id = progress.get("current_event_id", "")
+
+        def purged_chain_cover_txn(txn) -> int:
+            # The event ID from events will be null if the chain ID / sequence
+            # number points to a purged event.
+            sql = """
+                SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
+                FROM event_auth_chains
+                LEFT JOIN events AS e USING (event_id)
+                WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
+            """
+            txn.execute(sql, (current_event_id, batch_size))
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            # The event IDs and chain IDs / sequence numbers where the event has
+            # been purged.
+            unreferenced_event_ids = []
+            unreferenced_chain_id_tuples = []
+            event_id = ""
+            for event_id, chain_id, sequence_number, has_event in rows:
+                if not has_event:
+                    unreferenced_event_ids.append((event_id,))
+                    unreferenced_chain_id_tuples.append((chain_id, sequence_number))
+
+            # Delete the unreferenced auth chains from event_auth_chain_links and
+            # event_auth_chains.
+            txn.executemany(
+                """
+                DELETE FROM event_auth_chains WHERE event_id = ?
+                """,
+                unreferenced_event_ids,
+            )
+            # We should also delete matching target_*, but there is no index on
+            # target_chain_id. Hopefully any purged events are due to a room
+            # being fully purged and they will be removed from the origin_*
+            # searches.
+            txn.executemany(
+                """
+                DELETE FROM event_auth_chain_links WHERE
+                origin_chain_id = ? AND origin_sequence_number = ?
+                """,
+                unreferenced_chain_id_tuples,
+            )
+
+            progress = {
+                "current_event_id": event_id,
+            }
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "purged_chain_cover", progress
+            )
+
+            return len(rows)
+
+        result = await self.db_pool.runInteraction(
+            "_purged_chain_cover_index",
+            purged_chain_cover_txn,
+        )
+
+        if not result:
+            await self.db_pool.updates._end_background_update("purged_chain_cover")
+
+        return result
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
 import logging
 import threading
 from collections import namedtuple
@@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             set[str]: The events we have already seen.
         """
-        results = set()
+        # if the event cache contains the event, obviously we've seen it.
+        results = {x for x in event_ids if self._get_event_cache.contains(x)}
 
         def have_seen_events_txn(txn, chunk):
             sql = "SELECT event_id FROM events as e WHERE "
@@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore):
                 txn.database_engine, "e.event_id", chunk
             )
             txn.execute(sql + clause, args)
-            for (event_id,) in txn:
-                results.add(event_id)
+            results.update(row[0] for row in txn)
 
-        # break the input up into chunks of 100
-        input_iterator = iter(event_ids)
-        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+        for chunk in batch_iter((x for x in event_ids if x not in results), 100):
             await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
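The rewrite above first serves hits from the in-memory event cache and only queries the database for the remainder, in chunks of 100 via batch_iter. A small sketch of that chunking pattern, using an itertools-based stand-in rather than the synapse.util.iterutils implementation (the event ids are made up):

    from itertools import islice

    def batch_iter(iterable, size):
        # yield successive tuples of at most `size` items
        it = iter(iterable)
        while True:
            batch = tuple(islice(it, size))
            if not batch:
                return
            yield batch

    cached = {"$a", "$b"}
    remaining = (x for x in ["$a", "$b", "$c", "$d"] if x not in cached)
    assert [list(b) for b in batch_iter(remaining, 100)] == [["$c", "$d"]]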
@@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         txn.executemany(
             """
             DELETE FROM event_auth_chain_links WHERE
-            (origin_chain_id = ? AND origin_sequence_number = ?) OR
-            (target_chain_id = ? AND target_sequence_number = ?)
+            origin_chain_id = ? AND origin_sequence_number = ?
             """,
-            (
-                (chain_id, seq_num, chain_id, seq_num)
-                for (chain_id, seq_num) in referenced_chain_id_tuples
-            ),
+            referenced_chain_id_tuples,
         )
 
         # Now we delete tables which lack an index on room_id but have one on event_id
@@ -16,7 +16,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
@@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     async def user_delete_access_tokens(
         self,
         user_id: str,
-        except_token_id: Optional[str] = None,
+        except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
     ) -> List[Tuple[str, int, Optional[str]]]:
         """
@@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         items = keyvalues.items()
         where_clause = " AND ".join(k + " = ?" for k, _ in items)
-        values = [v for _, v in items]
+        values = [v for _, v in items]  # type: List[Union[str, int]]
         if except_token_id:
             where_clause += " AND id != ?"
             values.append(except_token_id)
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5910, 'purged_chain_cover', '{}');
@@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore):
 
         self.db_pool.simple_upsert_many_txn(
             txn,
-            "destination_rooms",
-            ["destination", "room_id"],
-            rows,
-            ["stream_ordering"],
-            [(stream_ordering,)] * len(rows),
+            table="destination_rooms",
+            key_names=("destination", "room_id"),
+            key_values=rows,
+            value_names=["stream_ordering"],
+            value_values=[(stream_ordering,)] * len(rows),
         )
 
     async def get_destination_last_successful_stream_ordering(
@@ -35,6 +35,14 @@ from typing import (
 import attr
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+    IReactorCore,
+    IReactorPluggableNameResolver,
+    IReactorTCP,
+    IReactorTime,
+)
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.stringutils import parse_and_validate_server_name
@@ -67,33 +75,40 @@ MutableStateMap = MutableMapping[StateKey, T]
 JsonDict = Dict[str, Any]
 
 
-class Requester(
-    namedtuple(
-        "Requester",
-        [
-            "user",
-            "access_token_id",
-            "is_guest",
-            "shadow_banned",
-            "device_id",
-            "app_service",
-            "authenticated_entity",
-        ],
-    )
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+    IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
 ):
+    """The interfaces necessary for Synapse to function."""
+
+
+@attr.s(frozen=True, slots=True)
+class Requester:
     """
     Represents the user making a request
 
     Attributes:
-        user (UserID): id of the user making the request
-        access_token_id (int|None): *ID* of the access token used for this
+        user: id of the user making the request
+        access_token_id: *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
-        is_guest (bool): True if the user making this request is a guest user
-        shadow_banned (bool): True if the user making this request has been shadow-banned.
-        device_id (str|None): device_id which was set at authentication time
-        app_service (ApplicationService|None): the AS requesting on behalf of the user
+        is_guest: True if the user making this request is a guest user
+        shadow_banned: True if the user making this request has been shadow-banned.
+        device_id: device_id which was set at authentication time
+        app_service: the AS requesting on behalf of the user
+        authenticated_entity: The entity that authenticated when making the request.
+            This is different to the user_id when an admin user or the server is
+            "puppeting" the user.
     """
 
+    user = attr.ib(type="UserID")
+    access_token_id = attr.ib(type=Optional[int])
+    is_guest = attr.ib(type=bool)
+    shadow_banned = attr.ib(type=bool)
+    device_id = attr.ib(type=Optional[str])
+    app_service = attr.ib(type=Optional["ApplicationService"])
+    authenticated_entity = attr.ib(type=str)
+
     def serialize(self):
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
@@ -141,23 +156,23 @@ class Requester(
 def create_requester(
     user_id: Union[str, "UserID"],
     access_token_id: Optional[int] = None,
-    is_guest: Optional[bool] = False,
-    shadow_banned: Optional[bool] = False,
+    is_guest: bool = False,
+    shadow_banned: bool = False,
     device_id: Optional[str] = None,
     app_service: Optional["ApplicationService"] = None,
     authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
     """
     Create a new ``Requester`` object
 
     Args:
-        user_id (str|UserID): id of the user making the request
-        access_token_id (int|None): *ID* of the access token used for this
+        user_id: id of the user making the request
+        access_token_id: *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
-        is_guest (bool): True if the user making this request is a guest user
-        shadow_banned (bool): True if the user making this request is shadow-banned.
-        device_id (str|None): device_id which was set at authentication time
-        app_service (ApplicationService|None): the AS requesting on behalf of the user
+        is_guest: True if the user making this request is a guest user
+        shadow_banned: True if the user making this request is shadow-banned.
+        device_id: device_id which was set at authentication time
+        app_service: the AS requesting on behalf of the user
         authenticated_entity: The entity that authenticated when making the request.
             This is different to the user_id when an admin user or the server is
             "puppeting" the user.
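As a rough usage sketch (the user and device values below are illustrative, not taken from this change), the attrs-based class is still built through `create_requester`, which now declares its return type:

from synapse.types import create_requester

requester = create_requester(
    "@alice:example.com",  # illustrative user id
    device_id="EXAMPLEDEVICE",  # illustrative device id
)
# The string user id is converted to a UserID; the remaining fields take their defaults.
print(requester.user, requester.is_guest, requester.shadow_banned)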
@@ -76,11 +76,16 @@ class ObservableDeferred:
         def callback(r):
             object.__setattr__(self, "_result", (True, r))
             while self._observers:
+                observer = self._observers.pop()
                 try:
-                    # TODO: Handle errors here.
-                    self._observers.pop().callback(r)
-                except Exception:
-                    pass
+                    observer.callback(r)
+                except Exception as e:
+                    logger.exception(
+                        "%r threw an exception on .callback(%r), ignoring...",
+                        observer,
+                        r,
+                        exc_info=e,
+                    )
             return r
 
         def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
                 # traces when we `await` on one of the observer deferreds.
                 f.value.__failure__ = f
 
+                observer = self._observers.pop()
                 try:
-                    # TODO: Handle errors here.
-                    self._observers.pop().errback(f)
-                except Exception:
-                    pass
+                    observer.errback(f)
+                except Exception as e:
+                    logger.exception(
+                        "%r threw an exception on .errback(%r), ignoring...",
+                        observer,
+                        f,
+                        exc_info=e,
+                    )
 
         if consumeErrors:
             return None
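A small sketch of typical `ObservableDeferred` use (illustrative; the behavioural change above is only that exceptions raised while firing an observer are now logged rather than silently swallowed):

from twisted.internet import defer

from synapse.util.async_helpers import ObservableDeferred

underlying = defer.Deferred()
observable = ObservableDeferred(underlying, consumeErrors=True)

# Several callers can observe the same underlying result independently.
first = observable.observe()
second = observable.observe()

underlying.callback("result")  # both observers now fire with "result"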
@@ -13,17 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import register_cache
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
     used rather than trying to compute a new response.
     """
 
-    def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+    def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
         # Requests that haven't finished yet.
         self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
 
-        self.clock = hs.get_clock()
+        self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
 
         self._name = name
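With this change a caller only needs a `Clock` rather than a whole `HomeServer`. A minimal sketch (hypothetical cache name and timeout; `wrap` is the pre-existing helper callers use to populate entries):

from synapse.util import Clock
from synapse.util.caches.response_cache import ResponseCache


def make_summary_cache(clock: Clock) -> ResponseCache:
    # Previously this required passing the HomeServer just to reach its clock.
    return ResponseCache(clock, "room_summary", timeout_ms=30 * 1000)


# Typical use: await cache.wrap(key, coroutine_fn, *args) deduplicates concurrent
# requests for the same key and keeps the result for timeout_ms afterwards.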
89
synapse/util/macaroons.py
Normal file
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+    """Extracts a caveat value from a macaroon token.
+
+    Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+    and returns the extracted value.
+
+    Args:
+        macaroon: the token
+        key: the key of the caveat to extract
+
+    Returns:
+        The extracted value
+
+    Raises:
+        MacaroonVerificationFailedException: if there are conflicting values for the
+            caveat in the macaroon, or if the caveat was not found in the macaroon.
+    """
+    prefix = key + " = "
+    result = None  # type: Optional[str]
+    for caveat in macaroon.caveats:
+        if not caveat.caveat_id.startswith(prefix):
+            continue
+
+        val = caveat.caveat_id[len(prefix) :]
+
+        if result is None:
+            # first time we found this caveat: record the value
+            result = val
+        elif val != result:
+            # on subsequent occurrences, raise if the value is different.
+            raise MacaroonVerificationFailedException(
+                "Conflicting values for caveat " + key
+            )
+
+    if result is not None:
+        return result
+
+    # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+    # Note that it is insecure to generate a macaroon without all the caveats you
+    # might need (because there is nothing stopping people from adding extra caveats),
+    # so if the caveat isn't there, something odd must be going on.
+    raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+    """Make a macaroon verifier which accepts 'time' caveats
+
+    Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+    the given macaroon verifier.
+
+    Args:
+        v: the macaroon verifier
+        get_time_ms: a callable which will return the timestamp after which the caveat
+            should be considered expired. Normally the current time.
+    """
+
+    def verify_expiry_caveat(caveat: str):
+        time_msec = get_time_ms()
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        return time_msec < expiry
+
+    v.satisfy_general(verify_expiry_caveat)
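A short usage sketch for the two new helpers (the secret, location and caveat values below are made up for illustration):

import time

import pymacaroons

from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry

macaroon = pymacaroons.Macaroon(
    location="example.com", identifier="key1", key="not-a-real-secret"
)
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("time < %d" % (int(time.time() * 1000) + 60000,))

# Pull a single caveat value back out (raises if it is missing or conflicting).
assert get_value_from_macaroon(macaroon, "user_id") == "@alice:example.com"

# Verify the macaroon, accepting unexpired "time <" caveats.
v = pymacaroons.Verifier()
v.satisfy_exact("user_id = @alice:example.com")
satisfy_expiry(v, lambda: int(time.time() * 1000))
v.verify(macaroon, "not-a-real-secret")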
@@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
 from synapse.federation.units import Edu
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
+from synapse.util.retryutils import NotRetryingDestination
 
 from tests.test_utils import event_injection, make_awaitable
 from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             else:
                 data = json_cb()
                 self.failed_pdus.extend(data["pdus"])
-                raise IOError("Failed to connect because this is a test!")
+                raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination)
 
     def get_destination_room(self, room: str, destination: str = "host2") -> dict:
         """
5
tests/handlers/oidc_test_key.p8
Normal file
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
4
tests/handlers/oidc_test_key.pub.pem
Normal file
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
     def test_short_term_login_token_gives_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
         )
-        self.assertEqual("a_user", user_id)
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("", res.auth_provider_id)
 
         # when we advance the clock, the token should be rejected
         self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            self.auth_handler.validate_short_term_login_token(token),
             AuthError,
         )
 
+    def test_short_term_login_token_gives_auth_provider(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", auth_provider_id="my_idp"
+        )
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
+        )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+        res = self.get_success(
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
         )
-        self.assertEqual("a_user", user_id)
+        self.assertEqual("a_user", res.user_id)
 
         # add another "user_id" caveat, which might allow us to override the
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            ),
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
             AuthError,
         )
 
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
 
     def _get_macaroon(self):
-        token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", "", 5000
+        )
         return pymacaroons.Macaroon.deserialize(token)
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     @override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
+import os
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +23,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
 JWKS_URI = ISSUER + ".well-known/jwks.json"
 
 # config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+    "enabled": True,
+    "client_id": CLIENT_ID,
+    "client_secret": CLIENT_SECRET,
+    "issuer": ISSUER,
+    "scopes": SCOPES,
+    "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+    **DEFAULT_CONFIG,
     "discover": False,
     "authorization_endpoint": AUTHORIZATION_ENDPOINT,
     "token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
         return {"keys": []}
 
 
+def _key_file_path() -> str:
+    """path to a file containing the private half of a test key"""
+
+    # this key was generated with:
+    #   openssl ecparam -name prime256v1 -genkey -noout |
+    #       openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+    #
+    # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+    # that's what Apple use, and we want to be sure that we work with Apple's keys.
+    #
+    # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+    # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+    # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+    # really need to care about any of that.)
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+    """path to a file containing the public half of a test key"""
+    # this was generated with:
+    #    openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+    #
+    # See above about where oidc_test_key.p8 came from
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
 class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
         return config
 
     def make_homeserver(self, reactor, clock):
@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.render_error.reset_mock()
|
self.render_error.reset_mock()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"discover": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||||
def test_discovery(self):
|
def test_discovery(self):
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
|
|
@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self):
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self):
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = self.get_success(self.provider.load_jwks())
|
jwks = self.get_success(self.provider.load_jwks())
|
||||||
|
|
@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self):
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.provider
|
h = self.provider
|
||||||
|
|
@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# Shouldn't raise with a valid userinfo, even without jwks
|
# Shouldn't raise with a valid userinfo, even without jwks
|
||||||
force_load_metadata()
|
force_load_metadata()
|
||||||
|
|
||||||
@override_config({"oidc_config": {"skip_verification": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||||
def test_skip_verification(self):
|
def test_skip_verification(self):
|
||||||
"""Provider metadata validation can be disabled by config."""
|
"""Provider metadata validation can be disabled by config."""
|
||||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
get_awaitable_result(self.provider.load_metadata())
|
get_awaitable_result(self.provider.load_metadata())
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self):
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["cookies"])
|
req = Mock(spec=["cookies"])
|
||||||
|
|
@@ -360,20 +387,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
         self.assertEqual(redirect, "http://client/redirect")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
@ -385,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_client", "some description")
|
self.assertRenderedError("invalid_client", "some description")
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback(self):
|
def test_callback(self):
|
||||||
"""Code callback works and display errors if something went wrong.
|
"""Code callback works and display errors if something went wrong.
|
||||||
|
|
||||||
|
|
@ -434,7 +457,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None, new_user=True
|
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
|
|
@ -465,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, request, client_redirect_url, None, new_user=False
|
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_not_called()
|
self.provider._parse_id_token.assert_not_called()
|
||||||
|
|
@ -486,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_session(self):
|
def test_callback_session(self):
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||||
|
|
@ -528,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
|
@override_config(
|
||||||
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||||
|
)
|
||||||
def test_exchange_code(self):
|
def test_exchange_code(self):
|
||||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
|
|
@ -613,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
"enabled": True,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"issuer": ISSUER,
|
||||||
|
"client_auth_method": "client_secret_post",
|
||||||
|
"client_secret_jwt_key": {
|
||||||
|
"key_file": _key_file_path(),
|
||||||
|
"jwt_header": {"alg": "ES256", "kid": "ABC789"},
|
||||||
|
"jwt_payload": {"iss": "DEFGHI"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_exchange_code_jwt_key(self):
|
||||||
|
"""Test that code exchange works with a JWK client secret."""
|
||||||
|
from authlib.jose import jwt
|
||||||
|
|
||||||
|
token = {"type": "bearer"}
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
code = "code"
|
||||||
|
|
||||||
|
# advance the clock a bit before we start, so we aren't working with zero
|
||||||
|
# timestamps.
|
||||||
|
self.reactor.advance(1000)
|
||||||
|
start_time = self.reactor.seconds()
|
||||||
|
ret = self.get_success(self.provider._exchange_code(code))
|
||||||
|
|
||||||
|
self.assertEqual(ret, token)
|
||||||
|
|
||||||
|
# the request should have hit the token endpoint
|
||||||
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
self.assertEqual(kwargs["method"], "POST")
|
||||||
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
|
||||||
|
|
||||||
|
# the client secret provided to the should be a jwt which can be checked with
|
||||||
|
# the public key
|
||||||
|
args = parse_qs(kwargs["data"].decode("utf-8"))
|
||||||
|
secret = args["client_secret"][0]
|
||||||
|
with open(_public_key_file_path()) as f:
|
||||||
|
key = f.read()
|
||||||
|
claims = jwt.decode(secret, key)
|
||||||
|
self.assertEqual(claims.header["kid"], "ABC789")
|
||||||
|
self.assertEqual(claims["aud"], ISSUER)
|
||||||
|
self.assertEqual(claims["iss"], "DEFGHI")
|
||||||
|
self.assertEqual(claims["sub"], CLIENT_ID)
|
||||||
|
self.assertEqual(claims["iat"], start_time)
|
||||||
|
self.assertGreater(claims["exp"], start_time)
|
||||||
|
|
||||||
|
# check the rest of the POSTed data
|
||||||
|
self.assertEqual(args["grant_type"], ["authorization_code"])
|
||||||
|
self.assertEqual(args["code"], [code])
|
||||||
|
self.assertEqual(args["client_id"], [CLIENT_ID])
|
||||||
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
"enabled": True,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"issuer": ISSUER,
|
||||||
|
"client_auth_method": "none",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_exchange_code_no_auth(self):
|
||||||
|
"""Test that code exchange works with no client secret."""
|
||||||
|
token = {"type": "bearer"}
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
code = "code"
|
||||||
|
ret = self.get_success(self.provider._exchange_code(code))
|
||||||
|
|
||||||
|
self.assertEqual(ret, token)
|
||||||
|
|
||||||
|
# the request should have hit the token endpoint
|
||||||
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
self.assertEqual(kwargs["method"], "POST")
|
||||||
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
|
||||||
|
|
||||||
|
# check the POSTed data
|
||||||
|
args = parse_qs(kwargs["data"].decode("utf-8"))
|
||||||
|
self.assertEqual(args["grant_type"], ["authorization_code"])
|
||||||
|
self.assertEqual(args["code"], [code])
|
||||||
|
self.assertEqual(args["client_id"], [CLIENT_ID])
|
||||||
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"module": __name__ + ".TestMappingProviderExtra"
|
"module": __name__ + ".TestMappingProviderExtra"
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -651,12 +773,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@foo:test",
|
"@foo:test",
|
||||||
|
"oidc",
|
||||||
request,
|
request,
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
{"phone": "1234567"},
|
{"phone": "1234567"},
|
||||||
new_user=True,
|
new_user=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_user(self):
|
def test_map_userinfo_to_user(self):
|
||||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
|
|
@ -668,7 +792,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", ANY, ANY, None, new_user=True
|
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
@ -679,7 +803,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user_2:test", ANY, ANY, None, new_user=True
|
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
@ -697,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"Mapping provider does not support de-duplicating Matrix IDs",
|
"Mapping provider does not support de-duplicating Matrix IDs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"allow_existing_users": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||||
def test_map_userinfo_to_existing_user(self):
|
def test_map_userinfo_to_existing_user(self):
|
||||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||||
store = self.hs.get_datastore()
|
store = self.hs.get_datastore()
|
||||||
|
|
@ -716,14 +840,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
@ -738,7 +862,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None, new_user=False
|
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
@ -774,9 +898,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@TEST_USER_2:test", ANY, ANY, None, new_user=False
|
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self):
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
|
@ -787,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"module": __name__ + ".TestMappingProviderFailures"
|
"module": __name__ + ".TestMappingProviderFailures"
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -810,7 +936,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", ANY, ANY, None, new_user=True
|
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
@ -834,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_empty_localpart(self):
|
def test_empty_localpart(self):
|
||||||
"""Attempts to map onto an empty localpart should be rejected."""
|
"""Attempts to map onto an empty localpart should be rejected."""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
|
|
@ -846,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"config": {"localpart_template": "{{ user.username }}"}
|
"config": {"localpart_template": "{{ user.username }}"}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -866,7 +994,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
state: str,
|
state: str,
|
||||||
nonce: str,
|
nonce: str,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
ui_auth_session_id: Optional[str] = None,
|
ui_auth_session_id: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
from synapse.handlers.oidc_handler import OidcSessionData
|
from synapse.handlers.oidc_handler import OidcSessionData
|
||||||
|
|
||||||
|
|
@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo(
|
||||||
idp_id="oidc",
|
idp_id="oidc",
|
||||||
nonce="nonce",
|
nonce="nonce",
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
|
ui_auth_session_id="",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
request = _build_callback_request("code", state, session)
|
request = _build_callback_request("code", state, session)
|
||||||
|
|
|
||||||
|
|
@@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(requester.shadow_banned)
 
+    def test_spam_checker_receives_sso_type(self):
+        """Test rejecting registration based on SSO type"""
+
+        class BanBadIdPUser:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info, auth_provider_id=None
+            ):
+                # Reject any user coming from CAS and whose username contains profanity
+                if auth_provider_id == "cas" and "flimflob" in username:
+                    return RegistrationBehaviour.DENY
+                return RegistrationBehaviour.ALLOW
+
+        # Configure a spam checker that denies a certain user on a specific IdP
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [BanBadIdPUser()]
+
+        f = self.get_failure(
+            self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
+            SynapseError,
+        )
+        exception = f.value
+
+        # We return 429 from the spam checker for denied registrations
+        self.assertIsInstance(exception, SynapseError)
+        self.assertEqual(exception.code, 429)
+
+        # Check the same username can register using SAML
+        self.get_success(
+            self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
+        )
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None, new_user=True
+            "@test_user1:test", "saml", request, "", None, new_user=True
        )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
@@ -16,12 +16,23 @@ from io import BytesIO

 from mock import Mock
+from netaddr import IPSet

+from twisted.internet.error import DNSLookupError
 from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
 from twisted.web.iweb import UNKNOWN_LENGTH

-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+    BlacklistingAgentWrapper,
+    BlacklistingReactorWrapper,
+    BodyExceededMaxSize,
+    read_body_with_max_size,
+)

+from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):

         # The data is never consumed.
         self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+    def setUp(self):
+        self.reactor, self.clock = get_clock()
+
+        self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+        self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+        self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+        # Configure the reactor's DNS resolver.
+        for (domain, ip) in (
+            (self.safe_domain, self.safe_ip),
+            (self.unsafe_domain, self.unsafe_ip),
+            (self.allowed_domain, self.allowed_ip),
+        ):
+            self.reactor.lookups[domain.decode()] = ip.decode()
+            self.reactor.lookups[ip.decode()] = ip.decode()
+
+        self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+        self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+    def test_reactor(self):
+        """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+        agent = Agent(
+            BlacklistingReactorWrapper(
+                self.reactor,
+                ip_whitelist=self.ip_whitelist,
+                ip_blacklist=self.ip_blacklist,
+            ),
+        )
+
+        # The unsafe domains and IPs should be rejected.
+        for domain in (self.unsafe_domain, self.unsafe_ip):
+            self.failureResultOf(
+                agent.request(b"GET", b"http://" + domain), DNSLookupError
+            )
+
+        # The safe domains IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            d = agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients[-1]
+
+            # Make the connection and pump data through it.
+            client = client_factory.buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+            )
+
+            response = self.successResultOf(d)
+            self.assertEqual(response.code, 200)
+
+    def test_agent(self):
+        """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+        agent = BlacklistingAgentWrapper(
+            Agent(self.reactor),
+            ip_whitelist=self.ip_whitelist,
+            ip_blacklist=self.ip_blacklist,
+        )
+
+        # The unsafe IPs should be rejected.
+        self.failureResultOf(
+            agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+        )
+
+        # The safe and unsafe domains and safe IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.unsafe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            d = agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients[-1]
+
+            # Make the connection and pump data through it.
+            client = client_factory.buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+            )
+
+            response = self.successResultOf(d)
+            self.assertEqual(response.code, 200)
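For background on the whitelist/blacklist checks these new tests exercise, the actual membership tests are done by `netaddr.IPSet`. A minimal, standalone sketch (not part of the diff; the helper name is made up) of the "whitelist punches holes in the blacklist" behaviour:

```python
from netaddr import IPAddress, IPSet

ip_blacklist = IPSet(["5.0.0.0/8"])
ip_whitelist = IPSet(["5.1.1.1"])


def is_ip_allowed(ip: str) -> bool:
    # Whitelisted addresses win even when they fall inside a blacklisted range,
    # mirroring how the wrappers above combine ip_whitelist and ip_blacklist.
    addr = IPAddress(ip)
    if addr in ip_whitelist:
        return True
    return addr not in ip_blacklist


assert is_ip_allowed("1.2.3.4")      # not blacklisted
assert is_ip_allowed("5.1.1.1")      # inside the blacklisted range, but whitelisted
assert not is_ip_allowed("5.6.7.8")  # blacklisted
```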
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type

-import attr
-
 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
+from twisted.web.server import Request, Site

 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -32,7 +31,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+    ReplicationStreamProtocolFactory,
+    ServerReplicationStreamProtocol,
+)
 from synapse.server import HomeServer
 from synapse.util import Clock

@@ -59,7 +61,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(
+            None
+        )  # type: ServerReplicationStreamProtocol

         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -152,12 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)

-        request_factory = OneShotRequestFactory()
-
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self.site
+        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)

         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -179,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()

-        return request_factory.request
+        return channel.request

     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
@@ -188,8 +188,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """

+        path = request.path  # type: bytes  # type: ignore
         self.assertRegex(
-            request.path,
+            path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
             % (stream_name.encode("ascii"),),
         )
@@ -232,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if self.hs.config.redis.redis_enabled:
             # Handle attempts to connect to fake redis server.
             self.reactor.add_tcp_client_callback(
-                "localhost",
+                b"localhost",
                 6379,
                 self.connect_any_redis_attempts,
             )
@@ -387,12 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)

-        request_factory = OneShotRequestFactory()
-
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self._hs_to_site[hs]
+        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])

         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -418,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         clients = self.reactor.tcpClients
         while clients:
             (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-            self.assertEqual(host, "localhost")
+            self.assertEqual(host, b"localhost")
             self.assertEqual(port, 6379)

             client_protocol = client_factory.buildProtocol(None)
@@ -450,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
             self.received_rdata_rows.append((stream_name, token, r))


-@attr.s()
-class OneShotRequestFactory:
-    """A simple request factory that generates a single `SynapseRequest` and
-    stores it for future use. Can only be used once.
-    """
-
-    request = attr.ib(default=None)
-
-    def __call__(self, *args, **kwargs):
-        assert self.request is None
-
-        self.request = SynapseRequest(*args, **kwargs)
-        return self.request
-
-
 class _PushHTTPChannel(HTTPChannel):
     """A HTTPChannel that wraps pull producers to push producers.

@@ -475,9 +457,13 @@ class _PushHTTPChannel(HTTPChannel):
     makes it very hard to test.
     """

-    def __init__(self, reactor: IReactorTime):
+    def __init__(
+        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
+    ):
         super().__init__()
         self.reactor = reactor
+        self.requestFactory = request_factory
+        self.site = site

         self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]

@@ -503,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
             request.responseHeaders.setRawHeaders(b"connection", [b"close"])
             return False

+    def requestDone(self, request):
+        # Store the request for inspection.
+        self.request = request
+        super().requestDone(request)
+

 class _PullToPushProducer:
     """A push producer that wraps a pull producer."""

@@ -590,6 +581,8 @@ class FakeRedisPubSubServer:
 class FakeRedisPubSubProtocol(Protocol):
     """A connection from a client talking to the fake Redis server."""

+    transport = None  # type: Optional[FakeTransport]
+
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
         self._reader = hiredis.Reader()

@@ -634,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):

     def send(self, msg):
         """Send a message back to the client."""
+        assert self.transport is not None
+
         raw = self.encode(msg).encode("utf-8")

         self.transport.write(raw)
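The `_PushHTTPChannel` changes above drop the one-shot request factory in favour of the channel's own `requestFactory`/`site` attributes plus a `requestDone` override that keeps the finished request around. As a rough, standalone sketch of that capture pattern (illustrative only, not part of the diff):

```python
from twisted.web.http import HTTPChannel


class RequestCapturingChannel(HTTPChannel):
    """Illustrative test helper: remember the last completed request so it can
    be inspected afterwards, mirroring _PushHTTPChannel.requestDone above."""

    request = None

    def requestDone(self, request):
        # Stash the request for inspection before the base class tidies up.
        self.request = request
        super().requestDone(request)
```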
@@ -17,7 +17,7 @@ import mock

 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams.federation import FederationStream

 from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
         """
         rch = self.hs.get_tcp_replication()

-        # wire up the ReplicationCommandHandler to a mock connection
-        mock_connection = mock.Mock(spec=AbstractConnection)
+        # wire up the ReplicationCommandHandler to a mock connection, which needs
+        # to implement IReplicationConnection. (Note that Mock doesn't understand
+        # interfaces, but casting an interface to a list gives the attributes.)
+        mock_connection = mock.Mock(spec=list(IReplicationConnection))
         rch.new_connection(mock_connection)

         # tell it it received an RDATA row
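The trick in the hunk above is that iterating a `zope.interface` interface yields the names it defines, which is exactly the list-of-names form that `Mock(spec=...)` accepts. A small, self-contained illustration (the `IGreeter` interface is invented for the example):

```python
from unittest import mock

from zope.interface import Interface


class IGreeter(Interface):
    """A toy interface with a single method."""

    def greet(name):
        """Return a greeting for `name`."""


# Iterating the interface gives its attribute/method names...
print(list(IGreeter))  # e.g. ['greet']

# ...which Mock(spec=...) understands as the set of allowed attributes.
greeter = mock.Mock(spec=list(IGreeter))
greeter.greet("world")   # allowed: "greet" is in the spec
# greeter.wave("world")  # would raise AttributeError: not in the spec
```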
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)

-        expected_flows = [
-            {"type": "m.login.cas"},
-            {"type": "m.login.sso"},
-            {"type": "m.login.token"},
-            {"type": "m.login.password"},
-        ] + ADDITIONAL_LOGIN_FLOWS
+        expected_flow_types = [
+            "m.login.cas",
+            "m.login.sso",
+            "m.login.token",
+            "m.login.password",
+        ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]

-        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+        self.assertCountEqual(
+            [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+        )

     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)

-    def test_client_idp_redirect_msc2858_disabled(self):
-        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
-        channel = self._make_sso_redirect_request(True, "oidc")
-        self.assertEqual(channel.code, 400, channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_unknown(self):
         """If the client tries to pick an unknown IdP, return a 404"""
-        channel = self._make_sso_redirect_request(True, "xxx")
+        channel = self._make_sso_redirect_request(False, "xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")

-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_oidc(self):
         """If the client pick a known IdP, redirect to it"""
+        channel = self._make_sso_redirect_request(False, "oidc")
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_msc2858_redirect_to_oidc(self):
+        """Test the unstable API"""
         channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 302, channel.result)
         oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+        channel = self._make_sso_redirect_request(True, "oidc")
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
     def _make_sso_redirect_request(
         self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
     ):
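The updated flow test above compares only the `type` fields of the flows returned by `GET /_matrix/client/r0/login`. A minimal client-side sketch of the same idea (illustrative only; the homeserver URL is made up, and `requests` is used purely for brevity):

```python
import requests

BASE_URL = "https://homeserver.example"  # hypothetical homeserver

resp = requests.get(f"{BASE_URL}/_matrix/client/r0/login")
resp.raise_for_status()

# The endpoint returns {"flows": [{"type": "m.login.password"}, ...]};
# comparing just the "type" values mirrors what the test now asserts.
flow_types = [flow["type"] for flow in resp.json()["flows"]]
print(flow_types)
print("SSO available:", "m.login.sso" in flow_types)
```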
@@ -105,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
         self.assertEqual(test_body, body)


-@attr.s
+@attr.s(slots=True, frozen=True)
 class _TestImage:
     """An image for testing thumbnailing with the expected results

@@ -117,13 +117,15 @@ class _TestImage:
            test should just check for success.
        expected_scaled: The expected bytes from scaled thumbnailing, or None if
            test should just check for a valid image returned.
+       expected_found: True if the file should exist on the server, or False if
+           a 404 is expected.
     """

     data = attr.ib(type=bytes)
     content_type = attr.ib(type=bytes)
     extension = attr.ib(type=bytes)
-    expected_cropped = attr.ib(type=Optional[bytes])
-    expected_scaled = attr.ib(type=Optional[bytes])
+    expected_cropped = attr.ib(type=Optional[bytes], default=None)
+    expected_scaled = attr.ib(type=Optional[bytes], default=None)
     expected_found = attr.ib(default=True, type=bool)


@@ -153,6 +155,21 @@ class _TestImage:
                 ),
             ),
         ),
+        # small png with transparency.
+        (
+            _TestImage(
+                unhexlify(
+                    b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+                    b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+                    b"4154789c636800000082008177cd72b60000000049454e44ae426"
+                    b"082"
+                ),
+                b"image/png",
+                b".png",
+                # Note that we don't check the output since it varies across
+                # different versions of Pillow.
+            ),
+        ),
         # small lossless webp
         (
             _TestImage(
@@ -162,8 +179,6 @@ class _TestImage:
                 ),
                 b"image/webp",
                 b".webp",
-                None,
-                None,
             ),
         ),
         # an empty file
@@ -172,9 +187,7 @@ class _TestImage:
                 b"",
                 b"image/gif",
                 b".gif",
-                None,
-                None,
-                False,
+                expected_found=False,
             ),
         ),
     ],
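The `_TestImage` change above leans on `attrs` defaults so that callers can omit the optional expectations, while `slots=True, frozen=True` makes the fixtures immutable. A small, standalone illustration of the same pattern (the class is invented for the example):

```python
from typing import Optional

import attr


@attr.s(slots=True, frozen=True)
class _ExampleImage:
    # Required fields come first; optional expectations default to None so
    # fixture tuples can leave them out, as in the hunk above.
    data = attr.ib(type=bytes)
    content_type = attr.ib(type=bytes)
    expected_cropped = attr.ib(type=Optional[bytes], default=None)
    expected_found = attr.ib(default=True, type=bool)


img = _ExampleImage(data=b"\x89PNG...", content_type=b"image/png")
print(img.expected_found)   # True, from the default
# img.data = b"other"       # would raise attr.exceptions.FrozenInstanceError
```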
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
     IReactorPluggableNameResolver,
     IReactorTCP,
     IResolverSimple,
+    ITransport,
 )
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -188,7 +189,7 @@ class FakeSite:

 def make_request(
     reactor,
-    site: Site,
+    site: Union[Site, FakeSite],
     method,
     path,
     content=b"",
@@ -467,6 +468,7 @@ def get_clock():
     return clock, hs_clock


+@implementer(ITransport)
 @attr.s(cmp=False)
 class FakeTransport:
     """
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])

-    @parameterized.expand([(True,), (False,)])
-    def test_auth_difference(self, use_chain_cover_index: bool):
+    def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"

         # The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }

-        # Mark the room as not having a cover index
+        # Mark the room as maybe having a cover index.

         def store_room(txn):
             self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
         )

+        return room_id
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_chain_ids(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
+        # a and b have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["a", "b"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        # d and e have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+        self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+        self.assertEqual(auth_chain_ids, ["k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+        self.assertEqual(auth_chain_ids, ["j"])
+
+        # j and k have no parents.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+        self.assertEqual(auth_chain_ids, [])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+        self.assertEqual(auth_chain_ids, [])
+
+        # More complex input sequences.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["h", "i"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+        # e gets returned even though include_given is false, but it is in the
+        # auth chain of b.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "e"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        # Test include_given.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+        )
+        self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
         # Now actually test that various combinations give the right result:

         difference = self.get_success(
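Conceptually, the `get_auth_chain_ids` assertions above treat event authorisation as a DAG and collect every ancestor reachable through `auth_events` edges. A minimal sketch of that reachability idea, with edges invented purely for illustration (the real test builds its own fixture earlier in the file, and the real store walks the chain cover index or recursive SQL rather than a Python dict):

```python
from typing import Dict, Iterable, List, Set

# Toy auth DAG: each event maps to the events it cites in auth_events.
auth_graph: Dict[str, List[str]] = {
    "a": ["e"],
    "b": ["e"],
    "c": ["g"],
    "d": ["f"],
    "e": ["f"],
    "f": ["g"],
    "g": ["h", "i"],
    "h": ["k"],
    "i": ["j"],
    "j": [],
    "k": [],
}


def auth_chain_ids(event_ids: Iterable[str], include_given: bool = False) -> Set[str]:
    """Collect every event reachable via auth_events edges (depth-first)."""
    seen: Set[str] = set(event_ids) if include_given else set()
    stack = list(event_ids)
    while stack:
        for parent in auth_graph.get(stack.pop(), []):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return seen


assert auth_chain_ids(["a"]) == {"e", "f", "g", "h", "i", "j", "k"}
assert auth_chain_ids(["h"]) == {"k"}
assert auth_chain_ids(["i"], include_given=True) == {"i", "j"}
```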
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from twisted.internet import defer
-
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.rest.client.v1 import room

 from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.room_id = self.helper.create_room_as(self.user_id)

-    def test_purge(self):
+        self.store = hs.get_datastore()
+        self.storage = self.hs.get_storage()
+
+    def test_purge_history(self):
         """
-        Purging a room will delete everything before the topological point.
+        Purging a room history will delete everything before the topological point.
         """
         # Send four messages to the room
         first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")

-        store = self.hs.get_datastore()
-        storage = self.hs.get_storage()
-
         # Get the topological token
         token = self.get_success(
-            store.get_topological_token_for_event(last["event_id"])
+            self.store.get_topological_token_for_event(last["event_id"])
         )
         token_str = self.get_success(token.to_string(self.hs.get_datastore()))

         # Purge everything before this topological token
         self.get_success(
-            storage.purge_events.purge_history(self.room_id, token_str, True)
+            self.storage.purge_events.purge_history(self.room_id, token_str, True)
         )

         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.
-        self.get_failure(store.get_event(first["event_id"]), NotFoundError)
-        self.get_failure(store.get_event(second["event_id"]), NotFoundError)
-        self.get_failure(store.get_event(third["event_id"]), NotFoundError)
-        self.get_success(store.get_event(last["event_id"]))
+        self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+        self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
+        self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
+        self.get_success(self.store.get_event(last["event_id"]))

-    def test_purge_wont_delete_extrems(self):
+    def test_purge_history_wont_delete_extrems(self):
         """
-        Purging a room will delete everything before the topological point.
+        Purging a room history will delete everything before the topological point.
         """
         # Send four messages to the room
         first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")

-        storage = self.hs.get_datastore()
-
         # Set the topological token higher than it should be
         token = self.get_success(
-            storage.get_topological_token_for_event(last["event_id"])
+            self.store.get_topological_token_for_event(last["event_id"])
         )
         event = "t{}-{}".format(token.topological + 1, token.stream + 1)

         # Purge everything before this topological token
-        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
-        self.pump()
-        f = self.failureResultOf(purge)
+        f = self.get_failure(
+            self.storage.purge_events.purge_history(self.room_id, event, True),
+            SynapseError,
+        )
         self.assertIn("greater than forward", f.value.args[0])

         # Try and get the events
-        self.get_success(storage.get_event(first["event_id"]))
-        self.get_success(storage.get_event(second["event_id"]))
-        self.get_success(storage.get_event(third["event_id"]))
-        self.get_success(storage.get_event(last["event_id"]))
+        self.get_success(self.store.get_event(first["event_id"]))
+        self.get_success(self.store.get_event(second["event_id"]))
+        self.get_success(self.store.get_event(third["event_id"]))
+        self.get_success(self.store.get_event(last["event_id"]))
+
+    def test_purge_room(self):
+        """
+        Purging a room will delete everything about it.
+        """
+        # Send four messages to the room
+        first = self.helper.send(self.room_id, body="test1")
+
+        # Get the current room state.
+        state_handler = self.hs.get_state_handler()
+        create_event = self.get_success(
+            state_handler.get_current_state(self.room_id, "m.room.create", "")
+        )
+        self.assertIsNotNone(create_event)
+
+        # Purge everything before this topological token
+        self.get_success(self.storage.purge_events.purge_room(self.room_id))
+
+        # The events aren't found.
+        self.store._invalidate_get_event_cache(create_event.event_id)
+        self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
+        self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
-        self.tx_log.emit(
+        self.tx_log.emit(  # type: ignore
             twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
         )

Some files were not shown because too many files have changed in this diff.